diff --git a/site/zh-cn/agents/tutorials/0_intro_rl.ipynb b/site/zh-cn/agents/tutorials/0_intro_rl.ipynb
index 9dc9e23bba..a31b10bc3f 100644
--- a/site/zh-cn/agents/tutorials/0_intro_rl.ipynb
+++ b/site/zh-cn/agents/tutorials/0_intro_rl.ipynb
@@ -6,7 +6,7 @@
"id": "I1JiGtmRbLVp"
},
"source": [
- "##### Copyright 2021 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
@@ -100,7 +100,7 @@
"\n",
"Q-Learning 基于 Q 函数的概念。策略 $\pi$ 的 Q 函数 $Q^{\pi}(s, a)$(又称状态-操作值函数)用于衡量从状态 $s$ 开始、首先采取操作 $a$、随后遵循策略 $\pi$ 所获得的预期回报或折扣奖励总和。我们将最优 Q 函数 $Q^*(s, a)$ 定义为从观测值 $s$ 开始,先采取操作 $a$,随后采取最优策略所能获得的最大回报。最优 Q 函数遵循以下*贝尔曼*最优性方程:\n",
"\n",
"$\begin{equation}Q^\ast(s, a) = \mathbb{E}[ r + \gamma \max_{a'} Q^\ast(s', a') ]\end{equation}$\n",
"\n",
"这意味着,从状态 $s$ 和操作 $a$ 获得的最大回报等于即时奖励 $r$ 与通过遵循最优策略,随后直到片段结束所获得的回报(折扣因子为 $\\gamma$)的总和(即,来自下一个状态 $s'$ 的最高奖励)。期望是在即时奖励 $r$ 的分布以及可能的下一个状态 $s'$ 的基础上计算的。\n",
"\n",
@@ -110,9 +112,9 @@
"\n",
"对于大多数问题,将 $Q$ 函数表示为包含 $s$ 和 $a$ 每种组合的值的表是不切实际的。相反,我们训练一个函数逼近器(例如,带参数 $\\theta$ 的神经网络)来估算 Q 值,即 $Q(s, a; \\theta) \\approx Q^*(s, a)$。这可以通过在每个步骤 $i$ 使以下损失最小化来实现:\n",
"\n",
- "$\\begin{equation}L_i(\\theta_i) = \\mathbb{E}*{s, a, r, s'\\sim \\rho(.)} \\left[ (y_i - Q(s, a; \\theta_i))^2 \\right]\\end{equation}$,其中 $y_i = r + \\gamma \\max*{a'} Q(s', a'; \\theta_{i-1})$\n",
+ "$\begin{equation}L_i(\theta_i) = \mathbb{E}_{s, a, r, s'\sim \rho(.)} \left[ (y_i - Q(s, a; \theta_i))^2 \right]\end{equation}$,其中 $y_i = r + \gamma \max_{a'} Q(s', a'; \theta_{i-1})$\n",
"\n",
"此处,$y_i$ 称为 TD(时间差分)目标,而 $y_i - Q$ 称为 TD 误差。$\rho$ 表示行为分布,即从环境中收集的转换 ${s, a, r, s'}$ 的分布。\n",
"\n",
"注意,先前迭代 $\\theta_{i-1}$ 中的参数是固定的,不会更新。实际上,我们使用前几次迭代而不是最后一次迭代的网络参数快照。此副本称为*目标网络*。\n",
"\n",
@@ -120,7 +122,7 @@
"\n",
"### 经验回放\n",
"\n",
- "为了避免计算 DQN 损失的全期望,我们可以使用随机梯度下降算法将其最小化。如果仅使用最后一个转换 ${s, a, r, s'}$ 来计算损失,那么这会简化为标准 Q-Learning。\n",
+ "为了避免计算 DQN 损失的全期望,我们可以使用随机梯度下降算法将其最小化。如果仅使用最后的转换 ${s, a, r, s'}$ 计算损失,这将简化为标准 Q-Learning。\n",
"\n",
"Atari DQN 工作引入了一种称为“经验回放”的技术,可使网络更新更加稳定。在数据收集的每个时间步骤,转换都会添加到称为*回放缓冲区*的循环缓冲区中。然后,在训练过程中,我们不是仅仅使用最新的转换来计算损失及其梯度,而是使用从回放缓冲区中采样的转换的 mini-batch 来计算它们。这样做有两个优点:通过在许多更新中重用每个转换来提高数据效率,以及在批次中使用不相关的转换来提高稳定性。\n"
]
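A minimal NumPy sketch of the two ideas above — the TD target and a circular replay buffer. This is an illustration only; names like `replay_buffer` and `td_targets` are not part of the tutorial, which uses the TF-Agents APIs introduced later in this series:

```python
import random
from collections import deque
import numpy as np

replay_buffer = deque(maxlen=100_000)  # circular buffer of (s, a, r, s') tuples

def sample_batch(batch_size=32):
    # Uniform sampling returns uncorrelated transitions for a stable update.
    return random.sample(replay_buffer, batch_size)

def td_targets(batch, q_next, gamma=0.99):
    # y_i = r + gamma * max_a' Q(s', a'), computed per sampled transition;
    # q_next holds the target network's Q-values for each s' in the batch.
    rewards = np.array([r for (_, _, r, _) in batch])
    return rewards + gamma * q_next.max(axis=1)
```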
diff --git a/site/zh-cn/agents/tutorials/10_checkpointer_policysaver_tutorial.ipynb b/site/zh-cn/agents/tutorials/10_checkpointer_policysaver_tutorial.ipynb
index 898347d50d..4545d35347 100644
--- a/site/zh-cn/agents/tutorials/10_checkpointer_policysaver_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/10_checkpointer_policysaver_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "W7rEsKyWcxmu"
},
"source": [
- "##### Copyright 2021 The TF-Agents Authors.\n"
+ "##### Copyright 2023 The TF-Agents Authors.\n"
]
},
{
diff --git a/site/zh-cn/agents/tutorials/1_dqn_tutorial.ipynb b/site/zh-cn/agents/tutorials/1_dqn_tutorial.ipynb
index 5ef5f890e2..c231fd32ca 100644
--- a/site/zh-cn/agents/tutorials/1_dqn_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/1_dqn_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "klGNgWREsvQv"
},
"source": [
- "##### Copyright 2021 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
@@ -40,12 +40,9 @@
"# 使用 TF-Agents 训练深度 Q 网络\n",
"\n",
""
]
diff --git a/site/zh-cn/agents/tutorials/2_environments_tutorial.ipynb b/site/zh-cn/agents/tutorials/2_environments_tutorial.ipynb
index 9747538d30..8d17f9d846 100644
--- a/site/zh-cn/agents/tutorials/2_environments_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/2_environments_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "Ma19Ks2CTDbZ"
},
"source": [
- "##### Copyright 2021 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
@@ -95,7 +95,6 @@
},
"outputs": [],
"source": [
- "!pip install \"gym>=0.21.0\"\n",
"!pip install tf-agents[reverb]\n"
]
},
diff --git a/site/zh-cn/agents/tutorials/3_policies_tutorial.ipynb b/site/zh-cn/agents/tutorials/3_policies_tutorial.ipynb
index 7855544030..596996f6ec 100644
--- a/site/zh-cn/agents/tutorials/3_policies_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/3_policies_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "1Pi_B2cvdBiW"
},
"source": [
- "##### Copyright 2021 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
@@ -40,12 +40,9 @@
"# 策略\n",
"\n",
""
]
diff --git a/site/zh-cn/agents/tutorials/4_drivers_tutorial.ipynb b/site/zh-cn/agents/tutorials/4_drivers_tutorial.ipynb
index 65ce704b39..04318a55a4 100644
--- a/site/zh-cn/agents/tutorials/4_drivers_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/4_drivers_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "beObUOFyuRjT"
},
"source": [
- "##### Copyright 2021 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
diff --git a/site/zh-cn/agents/tutorials/5_replay_buffers_tutorial.ipynb b/site/zh-cn/agents/tutorials/5_replay_buffers_tutorial.ipynb
index cbf4946b41..a8acd972b3 100644
--- a/site/zh-cn/agents/tutorials/5_replay_buffers_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/5_replay_buffers_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "beObUOFyuRjT"
},
"source": [
- "##### Copyright 2021 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
diff --git a/site/zh-cn/agents/tutorials/6_reinforce_tutorial.ipynb b/site/zh-cn/agents/tutorials/6_reinforce_tutorial.ipynb
index e7c53f2dbd..e994708395 100644
--- a/site/zh-cn/agents/tutorials/6_reinforce_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/6_reinforce_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "klGNgWREsvQv"
},
"source": [
- "##### Copyright 2021 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
@@ -40,12 +40,9 @@
"# REINFORCE 代理\n",
"\n",
""
]
diff --git a/site/zh-cn/agents/tutorials/7_SAC_minitaur_tutorial.ipynb b/site/zh-cn/agents/tutorials/7_SAC_minitaur_tutorial.ipynb
index 077c04dfce..691ed9cdf8 100644
--- a/site/zh-cn/agents/tutorials/7_SAC_minitaur_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/7_SAC_minitaur_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "klGNgWREsvQv"
},
"source": [
- "**Copyright 2021 The TF-Agents Authors.**"
+ "**Copyright 2023 The TF-Agents Authors.**"
]
},
{
diff --git a/site/zh-cn/agents/tutorials/8_networks_tutorial.ipynb b/site/zh-cn/agents/tutorials/8_networks_tutorial.ipynb
index 7a8ca20eea..f3c2804faf 100644
--- a/site/zh-cn/agents/tutorials/8_networks_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/8_networks_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "1Pi_B2cvdBiW"
},
"source": [
- "##### Copyright 2021 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
diff --git a/site/zh-cn/agents/tutorials/9_c51_tutorial.ipynb b/site/zh-cn/agents/tutorials/9_c51_tutorial.ipynb
index b77d604781..bc86f09f27 100644
--- a/site/zh-cn/agents/tutorials/9_c51_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/9_c51_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "klGNgWREsvQv"
},
"source": [
- "##### Copyright 2021 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
@@ -40,12 +40,9 @@
"# DQN C51/Rainbow\n",
"\n",
""
]
diff --git a/site/zh-cn/agents/tutorials/bandits_tutorial.ipynb b/site/zh-cn/agents/tutorials/bandits_tutorial.ipynb
index fb494d0351..baecf2e257 100644
--- a/site/zh-cn/agents/tutorials/bandits_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/bandits_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "klGNgWREsvQv"
},
"source": [
- "##### Copyright 2020 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
diff --git a/site/zh-cn/agents/tutorials/intro_bandit.ipynb b/site/zh-cn/agents/tutorials/intro_bandit.ipynb
index 8697da3a8a..7cd31888f4 100644
--- a/site/zh-cn/agents/tutorials/intro_bandit.ipynb
+++ b/site/zh-cn/agents/tutorials/intro_bandit.ipynb
@@ -6,7 +6,7 @@
"id": "I1JiGtmRbLVp"
},
"source": [
- "##### Copyright 2021 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
diff --git a/site/zh-cn/agents/tutorials/per_arm_bandits_tutorial.ipynb b/site/zh-cn/agents/tutorials/per_arm_bandits_tutorial.ipynb
index e0ed54a870..defeab987a 100644
--- a/site/zh-cn/agents/tutorials/per_arm_bandits_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/per_arm_bandits_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "nPjtEgqN4SjA"
},
"source": [
- "##### Copyright 2021 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
diff --git a/site/zh-cn/agents/tutorials/ranking_tutorial.ipynb b/site/zh-cn/agents/tutorials/ranking_tutorial.ipynb
index cf5bed7c7b..1449df6c4d 100644
--- a/site/zh-cn/agents/tutorials/ranking_tutorial.ipynb
+++ b/site/zh-cn/agents/tutorials/ranking_tutorial.ipynb
@@ -6,7 +6,7 @@
"id": "6tzp2bPEiK_S"
},
"source": [
- "##### Copyright 2022 The TF-Agents Authors."
+ "##### Copyright 2023 The TF-Agents Authors."
]
},
{
@@ -49,14 +49,10 @@
"### 开始\n",
"\n",
"\n"
]
},
diff --git a/site/zh-cn/community/mailing-lists.md b/site/zh-cn/community/mailing-lists.md
new file mode 100644
index 0000000000..01cfb66a3a
--- /dev/null
+++ b/site/zh-cn/community/mailing-lists.md
@@ -0,0 +1,42 @@
+# 邮寄名单
+
+作为一个社区,我们通过公开邮寄名单展开许多协作。请注意,如果您在寻找 TensorFlow 使用帮助,[TensorFlow 论坛](https://discuss.tensorflow.org/)、[Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow) 和 [GitHub 议题](https://github.com/tensorflow/tensorflow/issues)是最理想的初始选择。要想每季度接收来自 TensorFlow 团队的动态汇总,请订阅 [TensorFlow 简报](https://services.google.com/fb/forms/tensorflow/)。
+
+## TensorFlow 通用名单和论坛
+
+- [宣布](https://groups.google.com/a/tensorflow.org/d/forum/announce) - 新版本的低频公告。
+- [讨论](https://groups.google.com/a/tensorflow.org/d/forum/discuss) - 社区中关于 TensorFlow 的一般讨论。
+- [开发者](https://groups.google.com/a/tensorflow.org/d/forum/developers) - 为 TensorFlow 做出贡献的开发者的讨论。
+- [文档](https://discuss.tensorflow.org/tag/docs) - 为 TensorFlow 文档做贡献的讨论。有关语言特定的文档列表,请参阅[社区翻译](https://www.tensorflow.org/community/contribute/docs#community_translations)。
+- [测试](https://groups.google.com/a/tensorflow.org/d/forum/testing) - 有关 TensorFlow 2 测试的讨论和问题。
+
+## 特定项目的名单
+
+TensorFlow GitHub 组织内部的以下项目有专门用于各自社区的名单:
+
+- [hub](https://groups.google.com/a/tensorflow.org/d/forum/hub) – 围绕 [TensorFlow Hub](https://github.com/tensorflow/hub) 的讨论与合作。
+- [magenta-discuss](https://groups.google.com/a/tensorflow.org/d/forum/magenta-discuss) – 关于 [Magenta](https://magenta.tensorflow.org/) 发展和方向的一般讨论。
+- [tensor2tensor@tensorflow.org](https://groups.google.com/d/forum/tensor2tensor) - Tensor2Tensor 的讨论和同侪支持。
+- [tfjs-announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/d/forum/tfjs-announce) - 新 TensorFlow.js 版本的公告。
+- [tfjs@tensorflow.org](https://groups.google.com/a/tensorflow.org/d/forum/tfjs) - TensorFlow.js 的讨论和同侪支持。
+- [tflite@tensorflow.org](https://groups.google.com/a/tensorflow.org/d/forum/tflite) - TensorFlow Lite 的讨论和同侪支持。
+- [tfprobability@tensorflow.org](https://groups.google.com/a/tensorflow.org/d/forum/tfprobability) - TensorFlow Probability 的讨论和同侪支持。
+- [tfx](https://groups.google.com/a/tensorflow.org/forum/#!forum/tfx) – 围绕 [TensorFlow Extended (TFX)](https://www.tensorflow.org/tfx/) 的讨论与合作。
+- [tpu-users@tensorflow.org](https://groups.google.com/a/tensorflow.org/d/forum/tpu-users) - TPU 用户的社区讨论和支持。
+- [xla-dev@googlegroups.com](https://groups.google.com/forum/#!forum/xla-dev) - 为 XLA 做出贡献的开发者的讨论。
+
+## 特殊兴趣小组 (SIG)
+
+TensorFlow 的[特殊兴趣小组](https://github.com/tensorflow/community/tree/master/sigs) (SIG) 支持围绕特定项目重点开展社区协作。这些小组的成员通力合作,以构建和支持与 TensorFlow 相关的项目。尽管它们的归档是公开的,但不同的 SIG 各有自己的成员资格政策。
+
+- [addons](https://groups.google.com/a/tensorflow.org/d/forum/addons) – 支持 SIG Addons,负责符合稳定 API 要求的 TensorFlow 扩展程序。
+- [build](https://groups.google.com/a/tensorflow.org/d/forum/build) - 支持 SIG Build,用于 TensorFlow 的构建、分发和打包。
+- [io](https://groups.google.com/a/tensorflow.org/d/forum/io) – 支持 SIG IO,负责核心 TensorFlow 中未提供的文件系统和格式。
+- [jvm](https://groups.google.com/a/tensorflow.org/d/forum/jvm) – 支持 SIG JVM,为 TensorFlow 构建 Java 和 JVM 支持。
+- [keras](https://groups.google.com/forum/#!forum/keras-users) – Keras 用户邮寄名单,负责与 SIG Keras 相关的设计评审和讨论。
+- [micro](https://groups.google.com/a/tensorflow.org/d/forum/micro) – 支持 SIG Micro,专注于低功耗 TF Lite 部署。
+- [mlir](https://groups.google.com/a/tensorflow.org/d/forum/mlir) – 支持 SIG MLIR,围绕 MLIR(多层中间表示)开展协作。
+- [networking](https://groups.google.com/a/tensorflow.org/d/forum/networking) – 支持 SIG Networking,负责添加 gRPC 以外的网络协议。
+- [rust](https://groups.google.com/a/tensorflow.org/d/forum/rust) – 支持 SIG Rust,负责 Rust 语言绑定。
+- [swift](https://groups.google.com/a/tensorflow.org/d/forum/swift) – 支持 SIG Swift,负责开发 Swift for TensorFlow。
+- [tensorboard](https://groups.google.com/a/tensorflow.org/d/forum/tensorboard) – 支持 SIG TensorBoard,负责插件开发和其他贡献。
diff --git a/site/zh-cn/datasets/overview.ipynb b/site/zh-cn/datasets/overview.ipynb
index eab267e388..4a68ecb730 100644
--- a/site/zh-cn/datasets/overview.ipynb
+++ b/site/zh-cn/datasets/overview.ipynb
@@ -8,11 +8,11 @@
"source": [
"# TensorFlow Datasets\n",
"\n",
- "TFDS 提供了一组现成的数据集,适合与 TensorFlow、Jax 和其他机器学习框架配合使用。\n",
+ "TFDS provides a collection of ready-to-use datasets for use with TensorFlow, Jax, and other Machine Learning frameworks.\n",
"\n",
- "它可以确定地处理下载和准备数据并构造 `tf.data.Dataset`(或 `np.array`)。\n",
+ "It handles downloading and preparing the data deterministically and constructing a `tf.data.Dataset` (or `np.array`).\n",
"\n",
- "注:不要将 [TFDS](https://tensorflow.google.cn/datasets)(此库)与 `tf.data`(用于构建高效数据流水线的 TensorFlow API)混淆。TFDS 是 `tf.data` 的高级封装容器。如果您不熟悉此 API,建议您先阅读[官方 tf.data 指南](https://tensorflow.google.cn/guide/data)。\n"
+ "Note: Do not confuse [TFDS](https://tensorflow.google.cn/datasets) (this library) with `tf.data` (TensorFlow API to build efficient data pipelines). TFDS is a high level wrapper around `tf.data`. If you're not familiar with this API, we encourage you to read [the official tf.data guide](https://tensorflow.google.cn/guide/data) first.\n"
]
},
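A minimal usage sketch of the library described above (the `'mnist'` dataset name is just an example from the catalog):

```python
import tensorflow_datasets as tfds

# Downloads and prepares the data once, then returns a tf.data.Dataset.
ds = tfds.load('mnist', split='train', shuffle_files=True)
for example in ds.take(1):
    print(example['image'].shape, example['label'])
```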
{
@@ -21,7 +21,7 @@
"id": "J8y9ZkLXmAZc"
},
"source": [
- "版权所有 2018 TensorFlow 数据集作者,以 Apache License, Version 2.0 授权"
+ "Copyright 2018 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0"
]
},
{
diff --git a/site/zh-cn/datasets/splits.md b/site/zh-cn/datasets/splits.md
index 189a55a5ae..9900a94d67 100644
--- a/site/zh-cn/datasets/splits.md
+++ b/site/zh-cn/datasets/splits.md
@@ -1,6 +1,6 @@
# 拆分和切片
-所有 TFDS 数据集都公开了可以在[目录](https://www.tensorflow.org/datasets/catalog/overview)中浏览的各种数据拆分(例如 `'train'`、`'test'`)。
+所有 TFDS 数据集都提供了不同的数据拆分(例如 `'train'`、`'test'`),可以在[目录](https://www.tensorflow.org/datasets/catalog/overview)中进行探索。除了 `all`(它是一个保留术语,表示所有拆分的并集,见下文),任何字母字符串都可以用作拆分名称。
除了“官方”数据集拆分之外,TFDS 还允许选择拆分的切片和各种组合。
@@ -19,10 +19,10 @@ ds = builder.as_dataset(split='test+train[:75%]')
拆分可以是:
-- **普通拆分**(`'train'`、`'test'`):所选拆分中的所有样本。
+- **普通拆分名称**(字符串,例如 `'train'`、`'test'`…):所选拆分中的所有样本。
- **切片**:切片与 [python 切片表示法](https://docs.python.org/3/library/stdtypes.html#common-sequence-operations)具有相同的语义。切片可以是:
- **绝对**(`'train[123:450]'`、`train[:4000]`):(请参阅下方注释了解有关读取顺序的注意事项)
- - **百分比**(`'train[:75%]'`、`'train[25%:75%]'`):将完整数据分成 100 个均匀切片。如果数据不能被 100 整除,则某些百分比可能包含附加样本。
+ - **百分比** (`'train[:75%]'`、`'train[25%:75%]'`):将完整数据分成均匀的切片。如果数据不能被整除,则某些百分比可能包含附加样本。支持小数百分比。
- **分片**(`train[:4shard]`、`train[4shard]`):选择请求的分片中的所有样本。(请参阅 `info.splits['train'].num_shards` 以获取拆分的分片数)
- **拆分联合**(`'train+test'`、`'train[:25%]+test'`):拆分将交错在一起。
- **完整数据集** (`'all'`):`'all'` 是一个与所有拆分的联合对应的特殊拆分名称(相当于 `'train+test+...'`)。
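For reference, a sketch of the split strings listed above in use (`'mnist'` is a placeholder dataset name):

```python
import tensorflow_datasets as tfds

train_75 = tfds.load('mnist', split='train[:75%]')       # percent slice
absolute = tfds.load('mnist', split='train[123:450]')    # absolute slice
union    = tfds.load('mnist', split='train[:25%]+test')  # union of splits
full     = tfds.load('mnist', split='all')               # all splits combined
```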
@@ -118,7 +118,7 @@ ds = tfds.load('my_dataset', split=split)
- `%`:百分比切片
- `shard`:分片切片
-`tfds.ReadInstruction` 也有一个舍入参数。如果数据集中的样本数量不能被 `100` 整除:
+`tfds.ReadInstruction` 也有一个舍入参数。如果数据集中的样本数量不能被整除:
- `rounding='closest'`(默认):剩余的样本会在百分比中分布,因此某些百分比可能包含附加样本。
- `rounding='pct1_dropremainder'`:剩余的样本会被丢弃,但这可保证所有百分比均包含完全相同数量的样本(例如:`len(5%) == 5 * len(1%)`)。
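A sketch of the `tfds.ReadInstruction` equivalent with the rounding argument made explicit (reusing the `'my_dataset'` placeholder from above):

```python
import tensorflow_datasets as tfds

# Same selection as split='train[25%:75%]', with explicit rounding.
ri = tfds.ReadInstruction('train', from_=25, to=75, unit='%',
                          rounding='pct1_dropremainder')
ds = tfds.load('my_dataset', split=ri)
```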
diff --git a/site/zh-cn/federated/tutorials/custom_federated_algorithms_1.ipynb b/site/zh-cn/federated/tutorials/custom_federated_algorithms_1.ipynb
index 6090a7685b..a743021014 100644
--- a/site/zh-cn/federated/tutorials/custom_federated_algorithms_1.ipynb
+++ b/site/zh-cn/federated/tutorials/custom_federated_algorithms_1.ipynb
@@ -49,8 +49,7 @@
""
]
@@ -470,7 +469,7 @@
"\n",
"特别是,`tff.SERVER` 可能是单个物理设备(单一实例组的成员),但它也可能是运行状态机复制的容错集群中的一组副本,我们不做任何特殊架构的假设。相反,我们使用前一部分提到的 `all_equal` 位来表达我们通常只在服务器上处理单个数据项这一事实。\n",
"\n",
- "同样,在某些应用中,`tff.CLIENTS` 可能代表系统中的所有客户端,在联合学习的上下文中,我们有时将其称为*群体*,但在[联合平均的生产实现](https://arxiv.org/abs/1602.05629)这个示例中,它可能代表*队列*(选择参加某轮训练的客户端的子集)。当部署计算以执行(或者就像模型环境中的 Python 函数那样被调用)时,其中的抽象定义布局将被赋予具体含义。在我们的本地模拟中,客户端组由作为输入提供的联合数据来确定。"
+ "同样,在某些应用中,`tff.CLIENTS` 可能代表系统中的所有客户端,在联合学习的上下文中,我们有时将其称为*总体*,但在[联合平均的生产实现](https://arxiv.org/abs/1602.05629)这个示例中,它可能代表*队列*(选择参加某轮训练的客户端的子集)。当部署计算以执行时(或者只是像模拟环境中的 Python 函数那样被调用时),其中抽象定义的位置将被赋予具体含义。在我们的本地模拟中,客户端组由作为输入提供的联合数据来确定。"
]
},
{
diff --git a/site/zh-cn/federated/tutorials/federated_learning_for_text_generation.ipynb b/site/zh-cn/federated/tutorials/federated_learning_for_text_generation.ipynb
index c9b21515d2..ad7dc20bcf 100644
--- a/site/zh-cn/federated/tutorials/federated_learning_for_text_generation.ipynb
+++ b/site/zh-cn/federated/tutorials/federated_learning_for_text_generation.ipynb
@@ -47,11 +47,9 @@
},
"source": [
""
]
@@ -193,7 +191,7 @@
"outputs": [],
"source": [
"def generate_text(model, start_string):\n",
- " # From https://tensorflow.google.cn/tutorials/sequences/text_generation\n",
+ " # From https://www.tensorflow.org/tutorials/sequences/text_generation\n",
" num_generate = 200\n",
" input_eval = [char2idx[s] for s in start_string]\n",
" input_eval = tf.expand_dims(input_eval, 0)\n",
diff --git a/site/zh-cn/guide/core/mlp_core.ipynb b/site/zh-cn/guide/core/mlp_core.ipynb
index bdbb9bf563..fb183dc9d4 100644
--- a/site/zh-cn/guide/core/mlp_core.ipynb
+++ b/site/zh-cn/guide/core/mlp_core.ipynb
@@ -47,14 +47,12 @@
},
"source": [
""
]
},
@@ -90,7 +88,7 @@
"\n",
"建立感知器堆栈时,它们会构成称为密集层的结构,随后可连接以构建神经网络。密集层的方程与感知器的方程类似,区别是使用权重矩阵和偏差向量:\n",
"\n",
- "$$Y = \\mathrm{W}⋅\\mathrm{X} + \\vec{b}$$\n",
+ "$$Z = \\mathrm{W}⋅\\mathrm{X} + \\vec{b}$$\n",
"\n",
"其中\n",
"\n",
@@ -229,7 +227,7 @@
},
"outputs": [],
"source": [
- "sns.countplot(y_viz.numpy());\n",
+ "sns.countplot(x=y_viz.numpy());\n",
"plt.xlabel('Digits')\n",
"plt.title(\"MNIST Digit Distribution\");"
]
@@ -381,8 +379,8 @@
" if not self.built:\n",
" # Infer the input dimension based on first call\n",
" self.in_dim = x.shape[1]\n",
- " # Initialize the weights and biases using Xavier scheme\n",
- " self.w = tf.Variable(xavier_init(shape=(self.in_dim, self.out_dim)))\n",
+ " # Initialize the weights and biases\n",
+ " self.w = tf.Variable(self.weight_init(shape=(self.in_dim, self.out_dim)))\n",
" self.b = tf.Variable(tf.zeros(shape=(self.out_dim,)))\n",
" self.built = True\n",
" # Compute the forward pass\n",
@@ -720,7 +718,7 @@
"id": "tbrJJaFrD_XR"
},
"source": [
"## 保存和加载模型\n",
"\n",
"首先,构建一个接受原始数据并执行以下运算的导出模块:\n",
"\n",
@@ -870,9 +868,9 @@
" label_ind = (y_test == label)\n",
" # extract predictions for specific true label\n",
" pred_label = test_classes[label_ind]\n",
- " label_filled = tf.cast(tf.fill(pred_label.shape[0], label), tf.int64)\n",
+ " labels = y_test[label_ind]\n",
" # compute class-wise accuracy\n",
- " label_accs[accuracy_score(pred_label, label_filled).numpy()] = label\n",
+ " label_accs[accuracy_score(pred_label, labels).numpy()] = label\n",
"for key in sorted(label_accs):\n",
" print(f\"Digit {label_accs[key]}: {key:.3f}\")"
]
@@ -883,7 +881,7 @@
"id": "rcykuJFhdGb0"
},
"source": [
- "该模型在对某些数字进行分类时的性能似乎稍逊于其他数字,这种情况在许多多类分类问题中都十分常见。作为最后的练习,请绘制出模型预测的混淆矩阵及其对应的真实标签,以便在类级别收集更多见解。Sklearn 和 Seaborn 中具有生成和可视化混淆矩阵的函数。 "
+ "该模型在对某些数字进行分类时的性能似乎稍逊于其他数字,这种情况在许多多类分类问题中都十分常见。作为最后的练习,请绘制出模型预测的混淆矩阵及其对应的真实标签,以便在分类方面获得更深入的洞察力。Sklearn 和 Seaborn 中具有生成和可视化混淆矩阵的函数。 "
]
},
{
@@ -901,7 +899,7 @@
" plt.figure(figsize=(10,10))\n",
" confusion = sk_metrics.confusion_matrix(test_labels.numpy(), \n",
" test_classes.numpy())\n",
- " confusion_normalized = confusion / confusion.sum(axis=1)\n",
+ " confusion_normalized = confusion / confusion.sum(axis=1, keepdims=True)\n",
" axis_labels = range(10)\n",
" ax = sns.heatmap(\n",
" confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,\n",
@@ -936,7 +934,7 @@
"- 初始化方案有助于防止模型参数在训练期间消失或爆炸。\n",
"- 过拟合是神经网络的另一个常见问题,但本教程不存在此问题。有关这方面的更多帮助,请参阅[过拟合与欠拟合](overfit_and_underfit.ipynb)教程。\n",
"\n",
"有关使用 TensorFlow Core API 的更多示例,请查阅[指南](https://tensorflow.google.cn/guide/core)。如果您想详细了解如何加载和准备数据,请参阅有关[图像数据加载](https://tensorflow.google.cn/tutorials/load_data/images)或 [CSV 数据加载](https://tensorflow.google.cn/tutorials/load_data/csv)的教程。"
]
}
],
diff --git a/site/zh-cn/guide/create_op.md b/site/zh-cn/guide/create_op.md
index b2590b9c9b..748209810a 100644
--- a/site/zh-cn/guide/create_op.md
+++ b/site/zh-cn/guide/create_op.md
@@ -21,7 +21,7 @@
### 前提条件
- 对 C++ 有一定的了解。
-- 必须已安装 [TensorFlow 二进制文件](../../install),或者必须已[下载 TensorFlow 源代码](../../install/source.md),并且能够构建。
+- 必须已安装 [TensorFlow 二进制文件](https://www.tensorflow.org/install),或者必须已[下载 TensorFlow 源代码](https://www.tensorflow.org/install/source),并且能够构建。
## 定义运算接口
@@ -1083,6 +1083,8 @@ def _zero_out_grad(op, grad):
请注意,在调用梯度函数时,只有运算的数据流图可用,而张量数据本身不可用。因此,必须使用其他 TensorFlow 运算执行所有计算,以在计算图执行时运行。
+在为运算类型注册自定义梯度时添加类型提示,可以使代码更具可读性、更易于调试和维护,并且可以通过数据验证变得更加稳健。例如,当在函数中采用 `op` 作为参数时,可以指定梯度函数将采用 `tf.Operation` 作为其参数类型。
+
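A sketch of what such an annotated registration could look like, modeled on the `ZeroOut` gradient above (the type hints are the addition; the body is an illustrative reimplementation, not the guide's exact code):

```python
import tensorflow as tf
from tensorflow.python.framework import ops

@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op: tf.Operation, grad: tf.Tensor) -> list[tf.Tensor]:
    # Only the graph structure is available here, so the gradient must be
    # built from TensorFlow ops: all of it flows to the first element.
    to_zero = op.inputs[0]
    zeros = tf.reshape(tf.zeros_like(to_zero), [-1])
    first_grad = tf.reshape(grad, [-1])[0]
    flat_grad = tf.concat([[first_grad], zeros[1:]], axis=0)
    return [tf.reshape(flat_grad, tf.shape(to_zero))]  # one grad per input
```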
### C++ 中的形状函数
TensorFlow API 具有一项称为“形状推断”的功能,该功能可以提供有关张量形状的信息,而无需执行计算图。形状推断由在 C++ `REGISTER_OP` 声明中为每个运算类型注册的“形状函数”提供支持,并承担两个角色:声明输入的形状在计算图构造期间是兼容的,并指定输出的形状。
diff --git a/site/zh-cn/guide/distributed_training.ipynb b/site/zh-cn/guide/distributed_training.ipynb
index d4a2cae6fa..e179745a26 100644
--- a/site/zh-cn/guide/distributed_training.ipynb
+++ b/site/zh-cn/guide/distributed_training.ipynb
@@ -47,10 +47,14 @@
},
"source": [
""
]
},
@@ -60,7 +64,7 @@
"id": "xHxb-dlhMIzW"
},
"source": [
"## 概述\n",
"\n",
"`tf.distribute.Strategy` 是一个可在多个 GPU、多台机器或 TPU 上进行分布式训练的 TensorFlow API。使用此 API,您只需改动较少代码就能分布现有模型和训练代码。\n",
"\n",
@@ -119,8 +123,8 @@
"训练 API | `MirroredStrategy` | `TPUStrategy` | `MultiWorkerMirroredStrategy` | `CentralStorageStrategy` | `ParameterServerStrategy`\n",
":-- | :-- | :-- | :-- | :-- | :--\n",
"**Keras `Model.fit`** | 支持 | 支持 | 支持 | 实验性支持 | 实验性支持\n",
- "**自定义训练循环** | 支持 | 支持 | 支持 | 实验性支持 | 实验性支持\n",
"**自定义训练循环** | 支持 | 支持 | 支持 | 实验性支持 | 实验性支持\n",
+ "**Estimator API** | 有限支持 | 不支持 | 有限支持 | 有限支持 | 有限支持\n",
"\n",
"注:[实验性支持](https://tensorflow.google.cn/guide/versions#what_is_not_covered)是指兼容性保证不涵盖这些 API。\n",
"\n",
@@ -204,9 +208,9 @@
"source": [
"### TPUStrategy\n",
"\n",
- "您可以使用 `tf.distribute.experimental.TPUStrategy` 在张量处理单元 (TPU) 上运行 TensorFlow 训练。TPU 是 Google 的专用 ASIC,旨在显著加速机器学习工作负载。您可通过 Google Colab、[TensorFlow Research Cloud](https://tensorflow.google.cn/tfrc) 和 [Cloud TPU](https://cloud.google.com/tpu) 平台进行使用。\n",
+ "您可以使用 `tf.distribute.experimental.TPUStrategy` 在张量处理单元 (TPU) 上运行 TensorFlow 训练。TPU 是 Google 的专用 ASIC,旨在显著加速机器学习工作负载。您可以通过 Google Colab、[TensorFlow Research Cloud](https://tensorflow.google.cn/tfrc) 和 [Cloud TPU](https://cloud.google.com/tpu) 平台进行使用。\n",
"\n",
- "就分布式训练架构而言,`TPUStrategy` 和 `MirroredStrategy` 是一样的,即实现同步分布式训练。TPU 会在多个 TPU 核心之间实现高效的全归约和其他集合运算,并将其用于 `TPUStrategy`。\n",
+ "就分布式训练架构而言,`TPUStrategy` 与 `MirroredStrategy` 相同,即实现同步分布式训练。TPU 会在多个 TPU 核心之间实现高效的全归约和其他集合运算,并将其用于 `TPUStrategy`。\n",
"\n",
"下面演示了如何将 `TPUStrategy` 实例化:\n",
"\n",
@@ -297,7 +301,7 @@
"source": [
"### ParameterServerStrategy\n",
"\n",
- "参数服务器训练是一种常见的数据并行方法,用于在多台机器上扩展模型训练。参数服务器训练集群由工作进程和参数服务器组成。变量在参数服务器上创建,并在每个步骤中由工作进程读取和更新。查看[参数服务器培训](../tutorials/distribute/parameter_server_training.ipynb)教程了解详情。\n",
+ "参数服务器训练是一种常见的数据并行方法,用于在多台机器上扩展模型训练。参数服务器训练集群由工作进程和参数服务器组成。变量在参数服务器上创建,并在每个步骤中由工作进程读取和更新。请查阅[参数服务器培训](../tutorials/distribute/parameter_server_training.ipynb)教程以了解详情。\n",
"\n",
"在 TensorFlow 2 中,参数服务器训练通过 `tf.distribute.experimental.coordinator.Cluster Coordinator` 类使用基于中央协调器的架构。\n",
"\n",
@@ -313,9 +317,9 @@
" strategy)\n",
"```\n",
"\n",
- "要了解有关 `ParameterServerStrategy` 的详细信息,请参阅[使用 Keras Model.fit 和自定义训练循环进行参数服务器训练](../tutorials/distribute/parameter_server_training.ipynb)教程。\n",
+ "要详细了解 `ParameterServerStrategy`,请参阅[使用 Keras Model.fit 和自定义训练循环进行参数服务器训练](../tutorials/distribute/parameter_server_training.ipynb)教程。\n",
"\n",
- "注:如果使用 `TFConfigClusterResolver`,则需要配置 `'TF_CONFIG'` 环境变量。它类似于 `MultiWorkerMirroredStrategy` 中的 `'TF_CONFIG'`,但具有额外的注意事项。\n",
+ "注:如果使用 `TFConfigClusterResolver`,则需要配置 `'TF_CONFIG'` 环境变量。它类似于 `MultiWorkerMirroredStrategy` 中的 `'TF_CONFIG'`,但具有额外的注意事项。\n",
"\n",
"在 TensorFlow 1 中,`ParameterServerStrategy`只能通过 `tf.compat.v1.distribute.experimental.ParameterServerStrategy` 符号在 Estimator 中使用。"
]
@@ -339,7 +343,7 @@
"\n",
"`tf.distribute.experimental.CentralStorageStrategy` 也执行同步训练。变量不会被镜像,而是放在 CPU 上,且运算会复制到所有本地 GPU 。如果只有一个 GPU,则所有变量和运算都将被放在该 GPU 上。\n",
"\n",
- "请通过以下代码,创建 `CentralStorageStrategy` 实例:\n"
+ "请通过以下代码创建 `CentralStorageStrategy` 实例:\n"
]
},
{
@@ -392,7 +396,7 @@
"\n",
"默认策略是一种分布策略,当作用域内没有显式分布策略时就会出现。此策略会实现 `tf.distribute.Strategy` 接口,但只具有传递功能,不提供实际分布。例如,`Strategy.run(fn)` 只会调用 `fn`。使用该策略编写的代码与未使用任何策略编写的代码完全一样。您可以将其视为“无运算”策略。\n",
"\n",
"默认策略是一种单例,无法创建它的更多实例。可以在任何显式策略作用域之外使用 `tf.distribute.get_strategy` 来获取它(可用于在显式策略作用域内获取当前策略的相同 API)。"
]
},
{
@@ -414,7 +418,7 @@
"source": [
"此策略有两个主要用途:\n",
"\n",
- "- 它允许无条件地编写可感知分发的库代码。例如,在 `tf.keras.optimizers` 中,您可以使用 `tf.distribute.get_strategy`,并用此策略来降低梯度 - 它将始终返回一个策略对象,您可以在该对象上调用 `Strategy.reduce` API。\n"
+ "- 它允许无条件地编写可感知分布的库代码。例如,在 `tf.keras.optimizers` 中,您可以使用 `tf.distribute.get_strategy`,并用此策略来归约梯度 – 它将始终返回一个策略对象,您可以在该对象上调用 `Strategy.reduce` API。\n"
]
},
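A small sketch of that first use case — library code that reduces a per-replica value without knowing which strategy (possibly the default, no-op one) is active:

```python
import tensorflow as tf

def reduce_mean_anywhere(per_replica_value):
    # Returns the current strategy; the default strategy if none is active.
    strategy = tf.distribute.get_strategy()
    return strategy.reduce(tf.distribute.ReduceOp.MEAN,
                           per_replica_value, axis=None)
```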
{
@@ -472,7 +476,7 @@
"strategy = tf.distribute.OneDeviceStrategy(device=\"/gpu:0\")\n",
"```\n",
"\n",
"此策略在许多方面与默认策略不同。在默认策略中,与在未使用任何分布策略的情况下运行 TensorFlow 相比,变量布局逻辑保持不变。但是,在使用 `OneDeviceStrategy` 时,在其作用域内创建的所有变量都会显式放置在指定的设备上。此外,通过 `OneDeviceStrategy.run` 调用的任何函数也将放置在指定的设备上。\n",
"\n",
"通过此策略分布的输入将被预获取到指定设备。在默认策略中,没有输入分布。\n",
"\n",
@@ -496,7 +500,7 @@
"source": [
"## 在 Keras Model.fit 中使用 tf.distribute.Strategy\n",
"\n",
- "`tf.distribute.Strategy` 已集成到 `tf.keras` 中,后者是 TensorFlow 对 [Keras API 规范](https://keras.io/api/)的实现。`tf.keras` 是用于构建和训练模型的高级 API。通过集成到 `tf.keras` 后端,您可以无缝[使用 Model.fit](https://tensorflow.google.cn/guide/keras/customizing_what_happens_in_fit) 来分布以 Keras 训练框架编写的训练。\n",
+ "`tf.distribute.Strategy` 已集成到 `tf.keras` 中,后者是 TensorFlow 对 [Keras API 规范](https://keras.io/api/)的实现。`tf.keras` 是用于构建和训练模型的高级 API。通过集成到 `tf.keras` 后端,您可以无缝使用 `Model.fit` 来分布以 Keras 训练框架编写的训练。\n",
"\n",
"您需要对代码进行以下更改:\n",
"\n",
@@ -519,9 +523,10 @@
"mirrored_strategy = tf.distribute.MirroredStrategy()\n",
"\n",
"with mirrored_strategy.scope():\n",
- " model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])\n",
- "\n",
- "model.compile(loss='mse', optimizer='sgd')"
+ " model = tf.keras.Sequential([\n",
+ " tf.keras.layers.Dense(1, input_shape=(1,),\n",
+ " kernel_regularizer=tf.keras.regularizers.L2(1e-4))])\n",
+ " model.compile(loss='mse', optimizer='sgd')"
]
},
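Training then proceeds with ordinary Keras calls; a sketch with a toy dataset (`Model.fit` distributes the batches across replicas, and `model` is the strategy-scoped model above):

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(10)
model.fit(dataset, epochs=2)
model.evaluate(dataset)
```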
{
@@ -667,7 +672,9 @@
"outputs": [],
"source": [
"with mirrored_strategy.scope():\n",
- " model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])\n",
+ " model = tf.keras.Sequential([\n",
+ " tf.keras.layers.Dense(1, input_shape=(1,),\n",
+ " kernel_regularizer=tf.keras.regularizers.L2(1e-4))])\n",
" optimizer = tf.keras.optimizers.SGD()"
]
},
@@ -710,20 +717,21 @@
},
"outputs": [],
"source": [
+ "# Sets `reduction=NONE` to leave it to tf.nn.compute_average_loss() below.\n",
"loss_object = tf.keras.losses.BinaryCrossentropy(\n",
" from_logits=True,\n",
" reduction=tf.keras.losses.Reduction.NONE)\n",
"\n",
- "def compute_loss(labels, predictions):\n",
- " per_example_loss = loss_object(labels, predictions)\n",
- " return tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size)\n",
- "\n",
"def train_step(inputs):\n",
" features, labels = inputs\n",
"\n",
" with tf.GradientTape() as tape:\n",
" predictions = model(features, training=True)\n",
- " loss = compute_loss(labels, predictions)\n",
+ " per_example_loss = loss_object(labels, predictions)\n",
+ " loss = tf.nn.compute_average_loss(per_example_loss)\n",
+ " model_losses = model.losses\n",
+ " if model_losses:\n",
+ " loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))\n",
"\n",
" gradients = tape.gradient(loss, model.trainable_variables)\n",
" optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
@@ -744,9 +752,15 @@
"source": [
"以上代码还需注意以下几点:\n",
"\n",
- "1. 您使用了 `tf.nn.compute_average_loss` 来计算损失。`tf.nn.compute_average_loss` 将每个样本的损失相加,然后将总和除以 `global_batch_size`。这很重要,因为稍后在每个副本上计算出梯度后,会通过对它们**求和**使其在副本中聚合。\n",
- "2. 您还使用了 `tf.distribute.Strategy.reduce` API 来聚合 `tf.distribute.Strategy.run` 返回的结果。`tf.distribute.Strategy.run` 会从策略中的每个本地副本返回结果,您可以通过多种方式使用此结果。可以 `reduce` 它们以获得聚合值。还可以通过执行 `tf.distribute.Strategy.experimental_local_results` 获得包含在结果中的值的列表,每个本地副本一个列表。\n",
- "3. 当在一个分布策略作用域内调用 `apply_gradients` 时,它的行为会被修改。具体来说,在同步训练期间,在将梯度应用于每个并行实例之前,它会对梯度的所有副本求和。\n"
+ "1. 您使用了 `tf.nn.compute_average_loss` 将每个样本的预测损失归约为一个标量。`tf.nn.compute_average_loss` 将每个样本的损失相加,然后将总和除以全局批次大小。这很重要,因为稍后在每个副本上计算出梯度后,会通过对它们**求和**使其在副本中聚合。\n",
+ "\n",
+ "    默认情况下,全局批次大小为 `tf.get_strategy().num_replicas_in_sync * tf.shape(per_example_loss)[0]`。也可以将其显式指定为关键字参数 `global_batch_size=`。如果没有短批次,默认值相当于 `tf.nn.compute_average_loss(..., global_batch_size=global_batch_size)`(使用上面定义的 `global_batch_size`)。(有关短批次以及如何避免或处理它们的更多信息,请参阅[自定义训练](../tutorials/distribute/custom_training.ipynb)教程。)\n",
+ "\n",
+ "2. 您还使用了 `tf.nn.scale_regularization_loss` 将通过 `Model` 对象注册的正则化损失(如果有)缩放 `1/num_replicas_in_sync` 倍。对于依赖于输入的正则化损失,应由建模代码(而非自定义训练循环)对每个副本的批次大小求平均,这样建模代码可以保持与副本数无关,而训练循环也无需关心正则化损失的计算方式。\n",
+ "\n",
+ "3. 当您在一个分布策略作用域内调用 `apply_gradients` 时,它的行为会被修改。具体来说,在同步训练期间,在将梯度应用于每个并行实例之前,它会对梯度的所有副本求和。\n",
+ "\n",
+ "4. 您还使用了 `tf.distribute.Strategy.reduce` API 来聚合 `tf.distribute.Strategy.run` 返回的结果以进行报告。`tf.distribute.Strategy.run` 会从策略中的每个本地副本返回结果,您可以通过多种方式使用此结果。可以 `reduce` 它们以获得聚合值。还可以通过执行 `tf.distribute.Strategy.experimental_local_results` 获得包含在结果中的值的列表,每个本地副本一个列表。\n"
]
},
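A tiny sketch of the two scaling helpers discussed above, runnable outside any strategy scope (where `num_replicas_in_sync` is 1):

```python
import tensorflow as tf

per_example_loss = tf.constant([0.5, 1.0, 1.5, 2.0])  # one replica's batch

# Sum / global batch size, so summing gradients across replicas stays correct.
avg = tf.nn.compute_average_loss(per_example_loss, global_batch_size=8)

# Regularization terms are scaled by 1/num_replicas_in_sync instead.
reg = tf.nn.scale_regularization_loss(tf.constant(0.01))
print(float(avg + reg))  # 5.0 / 8 + 0.01 = 0.635
```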
{
diff --git a/site/zh-cn/guide/dtensor_overview.ipynb b/site/zh-cn/guide/dtensor_overview.ipynb
index 4d1bbe06cd..29b47cfe9e 100644
--- a/site/zh-cn/guide/dtensor_overview.ipynb
+++ b/site/zh-cn/guide/dtensor_overview.ipynb
@@ -167,7 +167,7 @@
"在一维 `Mesh` 中,所有设备会以单一网格维度构成列表。以下示例使用 `dtensor.create_mesh` 从 6 个 CPU 设备沿网格维度 `'x'` 创建了一个网格,大小为 6 个设备:\n",
"\n",
"\n",
- " \n"
+ " \n"
]
},
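The corresponding call, as a sketch that assumes 6 logical CPU devices have been configured for the process:

```python
from tensorflow.experimental import dtensor

# A 1-D mesh: 6 CPU devices laid out along the single mesh dimension 'x'.
mesh_1d = dtensor.create_mesh([('x', 6)],
                              devices=['CPU:%d' % i for i in range(6)])
print(mesh_1d)
```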
{
@@ -262,7 +262,7 @@
"使用相同的张量和网格,布局 `Layout(['unsharded', 'x'])` 将在 6 个设备上对张量的第二个轴进行分片。\n",
"\n",
"\n",
- " "
+ " "
]
},
{
@@ -291,7 +291,7 @@
"id": "Eyp_qOSyvieo"
},
"source": [
- " \n"
+ " \n"
]
},
{
@@ -314,7 +314,7 @@
"对于同一 `mesh_2d`,布局 `Layout([\"x\", dtensor.UNSHARDED], mesh_2d)` 是跨 `\"y\"` 复制的 2 秩 `Tensor` 的布局,其第一个轴在网格维度 `x` 上分片。\n",
"\n",
"\n",
- " \n"
+ " \n"
]
},
{
diff --git a/site/zh-cn/guide/function.ipynb b/site/zh-cn/guide/function.ipynb
index 6d81b5f465..71c1e94b9f 100644
--- a/site/zh-cn/guide/function.ipynb
+++ b/site/zh-cn/guide/function.ipynb
@@ -37,14 +37,12 @@
"id": "6DWfyNThSziV"
},
"source": [
"# 使用 tf.function 提升性能\n",
"\n",
""
]
@@ -55,7 +53,7 @@
"id": "J122XQYG7W6w"
},
"source": [
- "在 TensorFlow 2 中,[Eager Execution](eager.ipynb) 默认处于启用状态。界面非常灵活直观(执行一次性运算要简单快速得多),不过,这可能对性能和可部署性造成一定影响。\n",
+ "在 TensorFlow 2 中,Eager Execution 默认处于启用状态。界面非常灵活直观(执行一次性运算要简单快速得多),不过,这可能对性能和可部署性造成一定影响。\n",
"\n",
"您可以使用 `tf.function` 将程序转换为计算图。这是一个转换工具,用于从 Python 代码创建独立于 Python 的数据流图。它可以帮助您创建高效且可移植的模型,并且如果要使用 `SavedModel`,则必须使用此工具。\n",
"\n",
@@ -64,7 +62,7 @@
"要点和建议包括:\n",
"\n",
"- 先在 Eager 模式下调试,然后使用 `@tf.function` 进行装饰。\n",
- "- 不依赖 Python 的副作用,如对象变异或列表追加。\n",
+ "- 不依赖 Python 副作用,如对象变异或列表追加。\n",
"- `tf.function` 最适合处理 TensorFlow 运算;NumPy 和 Python 调用会转换为常量。\n"
]
},
@@ -74,7 +72,7 @@
"id": "SjvqpgepHJPd"
},
"source": [
"## 设置"
]
},
{
@@ -85,8 +83,6 @@
},
"outputs": [],
"source": [
- "# Update TensorFlow, as this notebook requires version 2.9 or later\n",
- "!pip install -q -U tensorflow>=2.9.0\n",
"import tensorflow as tf"
]
},
@@ -344,7 +340,7 @@
"- `tf.Graph` 与语言无关,是 TensorFlow 计算的原始可移植表示。\n",
"- `ConcreteFunction` 封装 `tf.Graph`。\n",
"- `Function` 管理 `ConcreteFunction` 的缓存,并为输入选择正确的缓存。\n",
- "- `tf.function` 包装 Python 函数,并返回一个 `Function` 对象。\n",
+ "- `tf.function` 封装 Python 函数,并返回一个 `Function` 对象。\n",
"- **跟踪**会创建 `tf.Graph` 并将其封装在 `ConcreteFunction` 中,也称为**跟踪**。\n"
]
},
@@ -363,12 +359,23 @@
"`TraceType` 由输入参数确定,具体如下所示:\n",
"\n",
"- 对于 `Tensor`,类型由 `Tensor` 的 `dtype` 和 `shape` 参数化;有秩形状是无秩形状的子类型;固定维度是未知维度的子类型\n",
+ "\n",
"- 对于 `Variable`,类型类似于 `Tensor`,但还包括变量的唯一资源 ID,这是正确连接控制依赖项所必需的\n",
+ "\n",
"- 对于 Python 基元值,类型对应于**值**本身。例如,值为 `3` 的 `TraceType` 是 `LiteralTraceType<3>`,而不是 `int`。\n",
+ "\n",
"- 对于 `list` 和 `tuple` 等 Python 有序容器,类型是通过其元素的类型来参数化的;例如,`[1, 2]` 的类型是 `ListTraceType, LiteralTraceType<2>>`,`[2, 1]` 的类型是 `ListTraceType, LiteralTraceType<1>>`,两者不同。\n",
+ "\n",
"- 对于 `dict` 等 Python 映射,类型也是从相同的键到值类型而不是实际值的映射。例如,`{1: 2, 3: 4}` 的类型为 `MappingTraceType<>>, >>>`。但是,与有序容器不同的是,`{1: 2, 3: 4}` 和 `{3: 4, 1: 2}` 具有等价的类型。\n",
+ "\n",
"- 对于实现 `__tf_tracing_type__` 方法的 Python 对象,类型为该方法返回的任何内容\n",
- "- 对于任何其他 Python 对象,类型是通用的 `TraceType`,它使用对象的 Python 相等性和散列进行匹配。(注:它依赖于对对象的[弱引用](https://docs.python.org/3/library/weakref.html),因此仅在对象处于范围内/未被删除时才有效。)\n"
+ "\n",
+ "- 对于任何其他 Python 对象,类型是通用的 `TraceType`,匹配过程如下:\n",
+ "\n",
+ " - 首先,它检查该对象与先前跟踪中使用的对象是否相同(使用 `id()` 或 `is`)。请注意,如果对象已更改,这仍然会匹配,因此如果您使用 Python 对象作为 `tf.function` 参数,最好使用*不可变*对象。\n",
+ " - 接下来,它检查该对象是否等于先前跟踪中使用的对象(使用 python `==`)。\n",
+ "\n",
+ "    请注意,此过程仅保留对象的[弱引用](https://docs.python.org/3/library/weakref.html),因此仅在对象处于范围内/未被删除时有效。\n"
]
},
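A quick illustration of the literal-value rule for Python primitives (tensors with the same `dtype`/`shape` share one trace, while each new Python value forces a retrace):

```python
import tensorflow as tf

@tf.function
def double(a):
    print('Tracing with', a)  # executes only while a new trace is created
    return a + a

double(tf.constant(1))  # traces once for scalar int32 tensors
double(tf.constant(2))  # same TraceType: no retrace, no print
double(3)               # Python literal 3: a new trace
double(4)               # literal 4: retraces again
```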
{
@@ -417,11 +424,11 @@
"\n",
"print(next_collatz(tf.constant([1, 2])))\n",
"# You specified a 1-D tensor in the input signature, so this should fail.\n",
- "with assert_raises(ValueError):\n",
+ "with assert_raises(TypeError):\n",
" next_collatz(tf.constant([[1, 2], [3, 4]]))\n",
"\n",
"# You specified an int32 dtype in the input signature, so this should fail.\n",
- "with assert_raises(ValueError):\n",
+ "with assert_raises(TypeError):\n",
" next_collatz(tf.constant([1.0, 2.0]))\n"
]
},
@@ -555,8 +562,8 @@
" flavor = tf.constant([3, 4])\n",
"\n",
"# As described in the above rules, a generic TraceType for `Apple` and `Mango`\n",
- "# is generated (and a corresponding ConcreteFunction is traced) but it fails to \n",
- "# match the second function call since the first pair of Apple() and Mango() \n",
+ "# is generated (and a corresponding ConcreteFunction is traced) but it fails to\n",
+ "# match the second function call since the first pair of Apple() and Mango()\n",
"# have gone out out of scope by then and deleted.\n",
"get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function\n",
"get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again\n",
@@ -567,26 +574,33 @@
"# can have significant performance benefits.\n",
"\n",
"class FruitTraceType(tf.types.experimental.TraceType):\n",
- " def __init__(self, fruit_type):\n",
- " self.fruit_type = fruit_type\n",
+ " def __init__(self, fruit):\n",
+ " self.fruit_type = type(fruit)\n",
+ " self.fruit_value = fruit\n",
"\n",
" def is_subtype_of(self, other):\n",
+ " # True if self subtypes `other` and `other`'s type matches FruitTraceType.\n",
" return (type(other) is FruitTraceType and\n",
" self.fruit_type is other.fruit_type)\n",
"\n",
" def most_specific_common_supertype(self, others):\n",
+ " # `self` is the specific common supertype if all input types match it.\n",
" return self if all(self == other for other in others) else None\n",
"\n",
+ " def placeholder_value(self, placeholder_context=None):\n",
+ " # Use the fruit itself instead of the type for correct tracing.\n",
+ " return self.fruit_value\n",
+ "\n",
" def __eq__(self, other):\n",
" return type(other) is FruitTraceType and self.fruit_type == other.fruit_type\n",
- " \n",
+ "\n",
" def __hash__(self):\n",
" return hash(self.fruit_type)\n",
"\n",
"class FruitWithTraceType:\n",
"\n",
" def __tf_tracing_type__(self, context):\n",
- " return FruitTraceType(type(self))\n",
+ " return FruitTraceType(self)\n",
"\n",
"class AppleWithTraceType(FruitWithTraceType):\n",
" flavor = tf.constant([1, 2])\n",
@@ -685,7 +699,7 @@
"id": "lar5A_5m5IG1"
},
"source": [
- "对不兼容的类型使用具体跟踪会引发错误"
+ "对不兼容的类型使用具体跟踪记录会引发错误"
]
},
{
@@ -706,7 +720,7 @@
"id": "st2L9VNQVtSG"
},
"source": [
- "您可能会注意到,在具体函数的输入签名中对 Python 参数进行了特别处理。TensorFlow 2.3 之前的版本会将 Python 参数直接从具体函数的签名中删除。从 TensorFlow 2.3 开始,Python 参数会保留在签名中,但是会受到约束,只能获取在跟踪期间设置的值。"
+ "您可能会注意到,在具体函数的输入签名中对 Python 参数进行了特别处理。TensorFlow 2.3 之前的版本会将 Python 参数直接从具体函数的签名中移除。从 TensorFlow 2.3 开始,Python 参数会保留在签名中,但是会受到约束,只能获取在跟踪期间设置的值。"
]
},
{
@@ -747,7 +761,7 @@
"source": [
"### 获取计算图\n",
"\n",
- "每个具体函数都是 `tf.Graph` 的可调用包装器。虽然一般不需要检索实际 `tf.Graph` 对象,不过,您可以从任何具体函数轻松获得实际对象。"
+ "每个具体函数都是 `tf.Graph` 的可调用封装容器。虽然一般不需要检索实际 `tf.Graph` 对象,不过,您可以从任何具体函数轻松获得实际对象。"
]
},
{
@@ -771,11 +785,11 @@
"source": [
"### 调试\n",
"\n",
- "通常,在 Eager 模式下调试代码比在 `tf.function` 中简单。在使用 `tf.function` 进行装饰之前,进行装饰之前,您应该先确保代码可在 Eager 模式下无错误执行。为了帮助调试,您可以调用 `tf.config.run_functions_eagerly(True)` 来全局停用和重新启用 `tf.function`。\n",
+ "通常,在 Eager 模式下调试代码比在 `tf.function` 中简单。在使用 `tf.function` 进行装饰之前,您应该先确保代码可在 Eager 模式下无错误执行。为了帮助调试,您可以调用 `tf.config.run_functions_eagerly(True)` 来全局停用和重新启用 `tf.function`。\n",
"\n",
"追溯仅在 `tf.function` 中出现的问题时,可参考下面的几点提示:\n",
"\n",
- "- 普通旧 Python `print` 调用仅在跟踪期间执行,可以帮助您在(重新)跟踪函数时进行追溯。\n",
+ "- 普通旧 Python `print` 调用仅在跟踪期间执行,可用于追溯(重新)跟踪函数的时间。\n",
"- `tf.print` 调用每次都会执行,可用于追溯执行过程中产生的中间值。\n",
"- 利用 `tf.debugging.enable_check_numerics` 很容易追溯到 NaN 和 Inf 在何处创建。\n",
"- `pdb`([Python 调试器](https://docs.python.org/3/library/pdb.html))可以帮助您理解跟踪的详细过程。(提醒:使用 `pdb` 调试时,AutoGraph 会自动转换 Python 源代码。)"
@@ -919,7 +933,7 @@
"\n",
"一个常见陷阱是在 `tf.function` 中的 Python/Numpy 数据上循环。此循环在跟踪过程中执行,因而循环每迭代一次,都会将模型的一个副本添加到 `tf.Graph`。\n",
"\n",
- "如果要在 `tf.function` 中包装整个训练循环,最安全的方法是将数据包装为 `tf.data.Dataset`,以便 AutoGraph 动态展开训练循环。"
+ "如果要在 `tf.function` 中封装整个训练循环,最安全的方式是将数据封装为 `tf.data.Dataset`,以便 AutoGraph 动态展开训练循环。"
]
},
{
@@ -972,7 +986,7 @@
"source": [
"#### 累加循环值\n",
"\n",
- "一种常见模式是不断累加循环的中间值。通常,这可以通过将元素追加到 Python 列表或将条目添加到 Python 字典来实现。但是,由于存在 Python 副作用,在动态展开循环中,这些方法无法达到预期效果。要从动态展开循环累加结果,可以使用 `tf.TensorArray` 来实现。"
+ "一种常见模式是不断累加循环的中间值。通常,这可以通过将元素追加到 Python 列表或将条目添加到 Python 字典来实现。但是,由于存在 Python 副作用,在动态展开循环中,这些方式无法达到预期效果。要从动态展开循环累加结果,可以使用 `tf.TensorArray` 来实现。"
]
},
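A short sketch of the `tf.TensorArray` pattern (a running sum accumulated across loop iterations inside a `tf.function`):

```python
import tensorflow as tf

@tf.function
def running_sums(xs):
    ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    total = tf.constant(0.0)
    for i in tf.range(tf.shape(xs)[0]):  # unrolled dynamically by AutoGraph
        total += xs[i]
        ta = ta.write(i, total)          # reassign: TensorArray ops are functional
    return ta.stack()

print(running_sums(tf.constant([1.0, 2.0, 3.0])))  # [1. 3. 6.]
```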
{
@@ -1514,7 +1528,7 @@
"id": "hvwe9gTIWfx6"
},
"source": [
- "#### 取决于 Python 对象"
+ "### 依赖于 Python 对象"
]
},
{
@@ -1523,7 +1537,11 @@
"id": "BJkZS-SwPvOQ"
},
"source": [
- "将 Python 对象作为参数传递给 `tf.function` 的建议存在许多已知问题,预计会在以后得到解决。通常,如果您使用 Python 基元或兼容 `tf.nest` 的结构作为参数,或将对象的*不同*实例传递给 `Function`,则可以依赖稳定的跟踪。但是,如果您传递**同一对象并仅更改其特性**时,`Function` 将*不会*创建新的跟踪记录。"
+ "支持将自定义 Python 对象作为参数传递给 `tf.function`,但有一定的限制。\n",
+ "\n",
+ "为了获得最大的特征覆盖率,请考虑在将对象传递给 `tf.function` 之前将其转换为[扩展类型](extension_type.ipynb)。此外,您也可以使用 Python 基元以及与 `tf.nest` 兼容的结构。\n",
+ "\n",
+ "但是,正如[跟踪规则](#rules_of_tracing)中所述,当自定义 Python 类未提供自定义 `TraceType` 时,`tf.function` 被迫使用基于实例的相等性,这意味着当您传递**具有修改特性的同一对象**时,它将**不会创建新的跟踪记录**。"
]
},
{
@@ -1568,9 +1586,9 @@
"id": "Ytcgg2qFWaBF"
},
"source": [
- "如果使用相同的 `Function` 评估模型的更新实例,那么更新后的模型与原始模型将具有[相同的缓存键](#rules_of_tracing),所以这种做法并不合理。\n",
+ "使用相同的 `Function` 评估模型的修改实例并不合理,因为它仍然具有与原始模型[相同的基于实例的 TraceType](#rules_of_tracing)。\n",
"\n",
- "因此,建议您编写 `Function` 以避免依赖于可变对象特性,或者创建新对象。\n",
+ "因此,建议您编写 `Function` 以避免依赖于可变对象特性,或者为对象实现[跟踪协议](#use_the_tracing_protocol)以将此类特性通知给 `Function`。\n",
"\n",
"如果这不可行,则一种解决方法是,每次修改对象时都创建新的 `Function` 以强制回溯:"
]
@@ -1589,7 +1607,7 @@
"new_model = SimpleModel()\n",
"evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)\n",
"# Don't pass in `new_model`, `Function` already captured its state during tracing.\n",
- "print(evaluate_no_bias(x)) "
+ "print(evaluate_no_bias(x))"
]
},
{
@@ -1734,7 +1752,7 @@
"source": [
"opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)\n",
"opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)\n",
- " \n",
+ "\n",
"@tf.function\n",
"def train_step(w, x, y, optimizer):\n",
" with tf.GradientTape() as tape:\n",
@@ -1784,13 +1802,13 @@
"y = tf.constant([2.])\n",
"\n",
"# Make a new Function and ConcreteFunction for each optimizer.\n",
- "train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)\n",
- "train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)\n",
+ "train_step_1 = tf.function(train_step)\n",
+ "train_step_2 = tf.function(train_step)\n",
"for i in range(10):\n",
" if i % 2 == 0:\n",
- " train_step_1(w, x, y) # `opt1` is not used as a parameter. \n",
+ " train_step_1(w, x, y, opt1)\n",
" else:\n",
- " train_step_2(w, x, y) # `opt2` is not used as a parameter."
+ " train_step_2(w, x, y, opt2)"
]
},
{
@@ -1820,7 +1838,6 @@
],
"metadata": {
"colab": {
- "collapsed_sections": [],
"name": "function.ipynb",
"toc_visible": true
},
diff --git a/site/zh-cn/guide/jax2tf.ipynb b/site/zh-cn/guide/jax2tf.ipynb
new file mode 100644
index 0000000000..c094b50204
--- /dev/null
+++ b/site/zh-cn/guide/jax2tf.ipynb
@@ -0,0 +1,850 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ckM5wJMsNTYL"
+ },
+ "source": [
+ "##### Copyright 2023 The TensorFlow Authors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "NKvERjPVNWxu"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bqePLdDjNhNk"
+ },
+ "source": [
+ "# 使用 JAX2TF 导入 JAX 模型"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gw3w46yhNiK_"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "IyrsY3uTOmPY"
+ },
+ "source": [
+ "此笔记本提供了一个完整的可运行示例,说明如何使用 [JAX](https://jax.readthedocs.io/en/latest/) 创建模型并将其导入 TensorFlow 以继续训练。这是通过 [JAX2TF](https://github.com/google/jax/tree/main/jax/experimental/jax2tf) 实现的,JAX2TF 是一种轻量级 API,提供了从 JAX 生态系统到 TensorFlow 生态系统的途径。\n",
+ "\n",
+ "JAX 是一个高性能数组计算库。为了创建模型,此笔记本使用 [Flax](https://flax.readthedocs.io/en/latest/),这是一种用于 JAX 的神经网络库。为了进行训练,它使用了 [Optax](https://optax.readthedocs.io),这是一种用于 JAX 的优化库。\n",
+ "\n",
+ "如果您是使用 JAX 的研究人员,JAX2TF 为您提供了一条使用 TensorFlow 成熟工具进行生产部署的路径。\n",
+ "\n",
+ "这可以有很多用途,以下只是其中几个:\n",
+ "\n",
+ "- 推断:采用为 JAX 编写的模型,并使用 TF Serving 将其部署在服务器上、使用 TFLite 在设备上部署或使用 TensorFlow.js 在 Web 上部署。\n",
+ "\n",
+ "- 微调:采用使用 JAX 训练的模型,您可以使用 JAX2TF 将其组件引入 TF,并使用您现有的训练数据和设置在 TensorFlow 中继续训练。\n",
+ "\n",
+ "- 融合:将使用 JAX 训练的模型部分与使用 TensorFlow 训练的模型部分相结合,以获得最大的灵活性。\n",
+ "\n",
+ "在 JAX 与 TensorFlow 之间实现这种互操作的关键是 `jax2tf.convert`,它接受在 JAX 上创建的模型组件(您的损失函数、预测函数等)并作为 TensorFlow 函数创建它们的等效表示,然后可以将其导出为 TensorFlow SavedModel。"
+ ]
+ },
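Before the full workflow below, the smallest possible sketch of that API (the `jax_sum_of_squares` function is illustrative, not part of this notebook):

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def jax_sum_of_squares(x):  # any JAX-traceable function
    return jnp.sum(x ** 2)

# jax2tf.convert returns a function that consumes and produces TF tensors.
tf_fn = tf.function(jax2tf.convert(jax_sum_of_squares), autograph=False)
print(tf_fn(tf.constant([1.0, 2.0])))  # tf.Tensor(5.0, shape=(), dtype=float32)
```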
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "G6rtu96yOepm"
+ },
+ "source": [
+ "## 安装\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9yqxfHzr0LPF"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "import numpy as np\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import flax\n",
+ "import optax\n",
+ "import os\n",
+ "from matplotlib import pyplot as plt\n",
+ "from jax.experimental import jax2tf\n",
+ "from threading import Lock # Only used in the visualization utility.\n",
+ "from functools import partial"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "SDnTaZO0r872"
+ },
+ "outputs": [],
+ "source": [
+ "# Needed for TensorFlow and JAX to coexist in GPU memory.\n",
+ "os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = \"false\"\n",
+ "gpus = tf.config.list_physical_devices('GPU')\n",
+ "if gpus:\n",
+ " try:\n",
+ " for gpu in gpus:\n",
+ " tf.config.experimental.set_memory_growth(gpu, True)\n",
+ " except RuntimeError as e:\n",
+ " # Memory growth must be set before GPUs have been initialized.\n",
+ " print(e)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "BXOjCNJxDLil"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Visualization utilities\n",
+ "\n",
+ "plt.rcParams[\"figure.figsize\"] = (20,8)\n",
+ "\n",
+ "# The utility for displaying training and validation curves.\n",
+ "def display_train_curves(loss, avg_loss, eval_loss, eval_accuracy, epochs, steps_per_epochs, ignore_first_n=10):\n",
+ "\n",
+ " ignore_first_n_epochs = int(ignore_first_n/steps_per_epochs)\n",
+ "\n",
+ " # The losses.\n",
+ " ax = plt.subplot(121)\n",
+ " if loss is not None:\n",
+ "    x = np.arange(len(loss)) / steps_per_epochs\n",
+ " ax.plot(x, loss)\n",
+ " ax.plot(range(1, epochs+1), avg_loss, \"-o\", linewidth=3)\n",
+ " ax.plot(range(1, epochs+1), eval_loss, \"-o\", linewidth=3)\n",
+ " ax.set_title('Loss')\n",
+ " ax.set_ylabel('loss')\n",
+ " ax.set_xlabel('epoch')\n",
+ " if loss is not None:\n",
+ " ax.set_ylim(0, np.max(loss[ignore_first_n:]))\n",
+ " ax.legend(['train', 'avg train', 'eval'])\n",
+ " else:\n",
+ " ymin = np.min(avg_loss[ignore_first_n_epochs:])\n",
+ " ymax = np.max(avg_loss[ignore_first_n_epochs:])\n",
+ " ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)\n",
+ " ax.legend(['avg train', 'eval'])\n",
+ "\n",
+ " # The accuracy.\n",
+ " ax = plt.subplot(122)\n",
+ " ax.set_title('Eval Accuracy')\n",
+ " ax.set_ylabel('accuracy')\n",
+ " ax.set_xlabel('epoch')\n",
+ " ymin = np.min(eval_accuracy[ignore_first_n_epochs:])\n",
+ " ymax = np.max(eval_accuracy[ignore_first_n_epochs:])\n",
+ " ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)\n",
+ " ax.plot(range(1, epochs+1), eval_accuracy, \"-o\", linewidth=3)\n",
+ "\n",
+ "class Progress:\n",
+ " \"\"\"Text mode progress bar.\n",
+ " Usage:\n",
+ " p = Progress(30)\n",
+ " p.step()\n",
+ " p.step()\n",
+ "    p.step(reset=True) # to restart from 0%\n",
+ " The progress bar displays a new header at each restart.\"\"\"\n",
+ " def __init__(self, maxi, size=100, msg=\"\"):\n",
+ " \"\"\"\n",
+ " :param maxi: the number of steps required to reach 100%\n",
+ " :param size: the number of characters taken on the screen by the progress bar\n",
+ " :param msg: the message displayed in the header of the progress bar\n",
+ " \"\"\"\n",
+ " self.maxi = maxi\n",
+ " self.p = self.__start_progress(maxi)() # `()`: to get the iterator from the generator.\n",
+ " self.header_printed = False\n",
+ " self.msg = msg\n",
+ " self.size = size\n",
+ " self.lock = Lock()\n",
+ "\n",
+ " def step(self, reset=False):\n",
+ " with self.lock:\n",
+ " if reset:\n",
+ " self.__init__(self.maxi, self.size, self.msg)\n",
+ " if not self.header_printed:\n",
+ " self.__print_header()\n",
+ " next(self.p)\n",
+ "\n",
+ " def __print_header(self):\n",
+ " print()\n",
+ " format_string = \"0%{: ^\" + str(self.size - 6) + \"}100%\"\n",
+ " print(format_string.format(self.msg))\n",
+ " self.header_printed = True\n",
+ "\n",
+ " def __start_progress(self, maxi):\n",
+ " def print_progress():\n",
+ " # Bresenham's algorithm. Yields the number of dots printed.\n",
+ " # This will always print 100 dots in max invocations.\n",
+ " dx = maxi\n",
+ " dy = self.size\n",
+ " d = dy - dx\n",
+ " for x in range(maxi):\n",
+ " k = 0\n",
+ " while d >= 0:\n",
+ " print('=', end=\"\", flush=True)\n",
+ " k += 1\n",
+ " d -= dx\n",
+ " d += dy\n",
+ " yield k\n",
+ " # Keep yielding the last result if there are too many steps.\n",
+ " while True:\n",
+ " yield k\n",
+ "\n",
+ " return print_progress"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6xgS_8nDDIu8"
+ },
+ "source": [
+ "## 下载并准备 MNIST 数据集"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nbN7rmuF0VFB"
+ },
+ "outputs": [],
+ "source": [
+ "(x_train, train_labels), (x_test, test_labels) = tf.keras.datasets.mnist.load_data()\n",
+ "\n",
+ "train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))\n",
+ "train_data = train_data.map(lambda x,y: (tf.expand_dims(tf.cast(x, tf.float32)/255.0, axis=-1),\n",
+ " tf.one_hot(y, depth=10)))\n",
+ "\n",
+ "BATCH_SIZE = 256\n",
+ "train_data = train_data.batch(BATCH_SIZE, drop_remainder=True)\n",
+ "train_data = train_data.cache()\n",
+ "train_data = train_data.shuffle(5000, reshuffle_each_iteration=True)\n",
+ "\n",
+ "test_data = tf.data.Dataset.from_tensor_slices((x_test, test_labels))\n",
+ "test_data = test_data.map(lambda x,y: (tf.expand_dims(tf.cast(x, tf.float32)/255.0, axis=-1),\n",
+ " tf.one_hot(y, depth=10)))\n",
+ "test_data = test_data.batch(10000)\n",
+ "test_data = test_data.cache()\n",
+ "\n",
+ "(one_batch, one_batch_labels) = next(iter(train_data)) # just one batch\n",
+ "(all_test_data, all_test_labels) = next(iter(test_data)) # all in one batch since batch size is 10000"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LuZTo7SM3W_n"
+ },
+ "source": [
+ "## 配置训练\n",
+ "\n",
+ "此笔记本将为演示目的创建并训练一个简单的模型。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3vbKB4yZ3aTL"
+ },
+ "outputs": [],
+ "source": [
+ "# Training hyperparameters.\n",
+ "JAX_EPOCHS = 3\n",
+ "TF_EPOCHS = 7\n",
+ "STEPS_PER_EPOCH = len(train_labels)//BATCH_SIZE\n",
+ "LEARNING_RATE = 0.01\n",
+ "LEARNING_RATE_EXP_DECAY = 0.6\n",
+ "\n",
+ "# The learning rate schedule for JAX (with Optax).\n",
+ "jlr_decay = optax.exponential_decay(LEARNING_RATE, transition_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)\n",
+ "\n",
+ "# The learning rate schedule for TensorFlow.\n",
+ "tflr_decay = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=LEARNING_RATE, decay_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Od3sMwQxtC34"
+ },
+ "source": [
+ "## 使用 Flax 创建模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-ybqQF2zd2QX"
+ },
+ "outputs": [],
+ "source": [
+ "class ConvModel(flax.linen.Module):\n",
+ "\n",
+ " @flax.linen.compact\n",
+ " def __call__(self, x, train):\n",
+ " x = flax.linen.Conv(features=12, kernel_size=(3,3), padding=\"SAME\", use_bias=False)(x)\n",
+ " x = flax.linen.BatchNorm(use_running_average=not train, use_scale=False, use_bias=True)(x)\n",
+ " x = x.reshape((x.shape[0], -1)) # flatten\n",
+ " x = flax.linen.Dense(features=200, use_bias=True)(x)\n",
+ " x = flax.linen.BatchNorm(use_running_average=not train, use_scale=False, use_bias=True)(x)\n",
+ " x = flax.linen.Dropout(rate=0.3, deterministic=not train)(x)\n",
+ " x = flax.linen.relu(x)\n",
+ " x = flax.linen.Dense(features=10)(x)\n",
+ " #x = flax.linen.log_softmax(x)\n",
+ " return x\n",
+ "\n",
+ " # JAX differentiation requires a function `f(params, other_state, data, labels)` -> `loss` (as a single number).\n",
+ "  # `jax.grad` will differentiate it against the first argument.\n",
+ " # The user must split trainable and non-trainable variables into `params` and `other_state`.\n",
+ " # Must pass a different RNG key each time for the dropout mask to be different.\n",
+ " def loss(self, params, other_state, rng, data, labels, train):\n",
+ " logits, batch_stats = self.apply({'params': params, **other_state},\n",
+ " data,\n",
+ " mutable=['batch_stats'],\n",
+ " rngs={'dropout': rng},\n",
+ " train=train)\n",
+ " # The loss averaged across the batch dimension.\n",
+ " loss = optax.softmax_cross_entropy(logits, labels).mean()\n",
+ " return loss, batch_stats\n",
+ "\n",
+ " def predict(self, state, data):\n",
+ " logits = self.apply(state, data, train=False) # predict and accuracy disable dropout and use accumulated batch norm stats (train=False)\n",
+ " probabilities = flax.linen.log_softmax(logits)\n",
+ " return probabilities\n",
+ "\n",
+ " def accuracy(self, state, data, labels):\n",
+ " probabilities = self.predict(state, data)\n",
+ " predictions = jnp.argmax(probabilities, axis=-1)\n",
+ " dense_labels = jnp.argmax(labels, axis=-1)\n",
+ " accuracy = jnp.equal(predictions, dense_labels).mean()\n",
+ " return accuracy"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7Cr0FRNFtHN4"
+ },
+ "source": [
+ "## 编写训练步骤函数"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "tmDwApcpgZzw"
+ },
+ "outputs": [],
+ "source": [
+ "# The training step.\n",
+ "@partial(jax.jit, static_argnums=[0]) # this forces jax.jit to recompile for every new model\n",
+ "def train_step(model, state, optimizer_state, rng, data, labels):\n",
+ "\n",
+ " other_state, params = state.pop('params') # differentiate only against 'params' which represents trainable variables\n",
+ " (loss, batch_stats), grads = jax.value_and_grad(model.loss, has_aux=True)(params, other_state, rng, data, labels, train=True)\n",
+ "\n",
+ " updates, optimizer_state = optimizer.update(grads, optimizer_state)\n",
+ " params = optax.apply_updates(params, updates)\n",
+ " new_state = state.copy(add_or_replace={**batch_stats, 'params': params})\n",
+ "\n",
+ " rng, _ = jax.random.split(rng)\n",
+ "\n",
+ " return new_state, optimizer_state, rng, loss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Zr16g6NzV4O9"
+ },
+ "source": [
+ "## 编写训练循环"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zbl5w-KUV7Qw"
+ },
+ "outputs": [],
+ "source": [
+ "def train(model, state, optimizer_state, train_data, epochs, losses, avg_losses, eval_losses, eval_accuracies):\n",
+ " p = Progress(STEPS_PER_EPOCH)\n",
+ " rng = jax.random.PRNGKey(0)\n",
+ " for epoch in range(epochs):\n",
+ "\n",
+ " # This is where the learning rate schedule state is stored in the optimizer state.\n",
+ " optimizer_step = optimizer_state[1].count\n",
+ "\n",
+ " # Run an epoch of training.\n",
+ " for step, (data, labels) in enumerate(train_data):\n",
+ " p.step(reset=(step==0))\n",
+ " state, optimizer_state, rng, loss = train_step(model, state, optimizer_state, rng, data.numpy(), labels.numpy())\n",
+ " losses.append(loss)\n",
+ " avg_loss = np.mean(losses[-step:])\n",
+ " avg_losses.append(avg_loss)\n",
+ "\n",
+ " # Run one epoch of evals (10,000 test images in a single batch).\n",
+ " other_state, params = state.pop('params')\n",
+ " # Gotcha: must discard modified batch_stats here\n",
+ " eval_loss, _ = model.loss(params, other_state, rng, all_test_data.numpy(), all_test_labels.numpy(), train=False)\n",
+ " eval_losses.append(eval_loss)\n",
+ " eval_accuracy = model.accuracy(state, all_test_data.numpy(), all_test_labels.numpy())\n",
+ " eval_accuracies.append(eval_accuracy)\n",
+ "\n",
+ " print(\"\\nEpoch\", epoch, \"train loss:\", avg_loss, \"eval loss:\", eval_loss, \"eval accuracy\", eval_accuracy, \"lr:\", jlr_decay(optimizer_step))\n",
+ "\n",
+ " return state, optimizer_state"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DGB3W5g0Wt1H"
+ },
+ "source": [
+ "## 创建模型和优化器(使用 Optax)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mW5mkmCWtN8W"
+ },
+ "outputs": [],
+ "source": [
+ "# The model.\n",
+ "model = ConvModel()\n",
+ "state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for \"dropout\"\n",
+ "\n",
+ "# The optimizer.\n",
+ "optimizer = optax.adam(learning_rate=jlr_decay) # Gotcha: it does not seem to be possible to pass just a callable as LR, must be an Optax Schedule\n",
+ "optimizer_state = optimizer.init(state['params'])\n",
+ "\n",
+ "losses=[]\n",
+ "avg_losses=[]\n",
+ "eval_losses=[]\n",
+ "eval_accuracies=[]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FJdsKghBNF"
+ },
+ "source": [
+ "## 训练模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nmcofTTBZSIb"
+ },
+ "outputs": [],
+ "source": [
+ "new_state, new_optimizer_state = train(model, state, optimizer_state, train_data, JAX_EPOCHS+TF_EPOCHS, losses, avg_losses, eval_losses, eval_accuracies)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "n_20vgvDXB5r"
+ },
+ "outputs": [],
+ "source": [
+ "display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=1*STEPS_PER_EPOCH)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0lT3cdENCBzL"
+ },
+ "source": [
+ "## 部分训练模型\n",
+ "\n",
+ "您很快将在 TensorFlow 中继续训练模型。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KT-xqj5N7C6L"
+ },
+ "outputs": [],
+ "source": [
+ "model = ConvModel()\n",
+ "state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for \"dropout\"\n",
+ "\n",
+ "# The optimizer.\n",
+ "optimizer = optax.adam(learning_rate=jlr_decay) # LR must be an Optax LR Schedule\n",
+ "optimizer_state = optimizer.init(state['params'])\n",
+ "\n",
+ "losses, avg_losses, eval_losses, eval_accuracies = [], [], [], []"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "oa362HMDbzDE"
+ },
+ "outputs": [],
+ "source": [
+ "state, optimizer_state = train(model, state, optimizer_state, train_data, JAX_EPOCHS, losses, avg_losses, eval_losses, eval_accuracies)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0IyZtUPPCt0y"
+ },
+ "outputs": [],
+ "source": [
+ "display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=1*STEPS_PER_EPOCH)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uNtlSaOCCumB"
+ },
+ "source": [
+ "## 保存推断所需的内容\n",
+ "\n",
+ "如果您的目标是部署 JAX 模型(以便您可以使用 `model.predict()` 运行推断),只需将其导出到 [SavedModel](https://tensorflow.google.cn/guide/saved_model)。本节演示如何实现这一点。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "O653B3-5H8FL"
+ },
+ "outputs": [],
+ "source": [
+ "# Test data with a different batch size to test polymorphic shapes.\n",
+ "x, y = next(iter(train_data.unbatch().batch(13)))\n",
+ "\n",
+ "m = tf.Module()\n",
+ "# Wrap the JAX state in `tf.Variable` (needed when calling the converted JAX function.\n",
+ "state_vars = tf.nest.map_structure(tf.Variable, state)\n",
+ "# Keep the wrapped state as flat list (needed in TensorFlow fine-tuning).\n",
+ "m.vars = tf.nest.flatten(state_vars)\n",
+ "# Convert the desired JAX function (`model.predict`).\n",
+ "predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=[\"...\", \"(b, 28, 28, 1)\"])\n",
+ "# Wrap the converted function in `tf.function` with the correct `tf.TensorSpec` (necessary for dynamic shapes to work).\n",
+ "@tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])\n",
+ "def predict(data):\n",
+ " return predict_fn(state_vars, data)\n",
+ "m.predict = predict\n",
+ "tf.saved_model.save(m, \"./\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8HFx67zStgvo"
+ },
+ "outputs": [],
+ "source": [
+ "# Test the converted function.\n",
+ "print(\"Converted function predictions:\", np.argmax(m.predict(x).numpy(), axis=-1))\n",
+ "# Reload the model.\n",
+ "reloaded_model = tf.saved_model.load(\"./\")\n",
+ "# Test the reloaded converted function (the result should be the same).\n",
+ "print(\"Reloaded function predictions:\", np.argmax(reloaded_model.predict(x).numpy(), axis=-1))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eEk8wv4HJu94"
+ },
+ "source": [
+ "## 保存一切\n",
+ "\n",
+ "如果您的目标是全面导出(如果您计划将模型导入 TensorFlow 进行微调、融合等,这很有用),本节将演示如何保存模型,以便您可以访问以下方法:\n",
+ "\n",
+ "- model.predict\n",
+ "- model.accuracy\n",
+ "- model.loss(包括 train=True/False bool,用于随机失活和 BatchNorm 状态更新的 RNG)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9mty52pmvDDp"
+ },
+ "outputs": [],
+ "source": [
+ "from collections import abc\n",
+ "\n",
+ "def _fix_frozen(d):\n",
+ " \"\"\"Changes any mappings (e.g. frozendict) back to dict.\"\"\"\n",
+ " if isinstance(d, list):\n",
+ " return [_fix_frozen(v) for v in d]\n",
+ " elif isinstance(d, tuple):\n",
+ " return tuple(_fix_frozen(v) for v in d)\n",
+ " elif not isinstance(d, abc.Mapping):\n",
+ " return d\n",
+ " d = dict(d)\n",
+ " for k, v in d.items():\n",
+ " d[k] = _fix_frozen(v)\n",
+ " return d"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3HEsKNXbCwXw"
+ },
+ "outputs": [],
+ "source": [
+ "class TFModel(tf.Module):\n",
+ " def __init__(self, state, model):\n",
+ " super().__init__()\n",
+ "\n",
+ " # Special care needed for the train=True/False parameter in the loss\n",
+ " @jax.jit\n",
+ " def loss_with_train_bool(state, rng, data, labels, train):\n",
+ " other_state, params = state.pop('params')\n",
+ " loss, batch_stats = jax.lax.cond(train,\n",
+ " lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=True),\n",
+ " lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=False),\n",
+ " state, data, labels)\n",
+ " # must use JAX to split the RNG, therefore, must do it in a @jax.jit function\n",
+ " new_rng, _ = jax.random.split(rng)\n",
+ " return loss, batch_stats, new_rng\n",
+ "\n",
+ " self.state_vars = tf.nest.map_structure(tf.Variable, state)\n",
+ " self.vars = tf.nest.flatten(self.state_vars)\n",
+ " self.jax_rng = tf.Variable(jax.random.PRNGKey(0))\n",
+ "\n",
+ " self.loss_fn = jax2tf.convert(loss_with_train_bool, polymorphic_shapes=[\"...\", \"...\", \"(b, 28, 28, 1)\", \"(b, 10)\", \"...\"])\n",
+ " self.accuracy_fn = jax2tf.convert(model.accuracy, polymorphic_shapes=[\"...\", \"(b, 28, 28, 1)\", \"(b, 10)\"])\n",
+ " self.predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=[\"...\", \"(b, 28, 28, 1)\"])\n",
+ "\n",
+ " # Must specify TensorSpec manually for variable batch size to work\n",
+ " @tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])\n",
+ " def predict(self, data):\n",
+ " # Make sure the TfModel.predict function implicitly use self.state_vars and not the JAX state directly\n",
+ " # otherwise, all model weights would be embedded in the TF graph as constants.\n",
+ " return self.predict_fn(self.state_vars, data)\n",
+ "\n",
+ " @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),\n",
+ " tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n",
+ " autograph=False)\n",
+ " def train_loss(self, data, labels):\n",
+ " loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, True)\n",
+ " # update batch norm stats\n",
+ " flat_vars = tf.nest.flatten(self.state_vars['batch_stats'])\n",
+ " flat_values = tf.nest.flatten(batch_stats['batch_stats'])\n",
+ " for var, val in zip(flat_vars, flat_values):\n",
+ " var.assign(val)\n",
+ " # update RNG\n",
+ " self.jax_rng.assign(new_rng)\n",
+ " return loss\n",
+ "\n",
+ " @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),\n",
+ " tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n",
+ " autograph=False)\n",
+ " def eval_loss(self, data, labels):\n",
+ " loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, False)\n",
+ " return loss\n",
+ "\n",
+ " @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),\n",
+ " tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n",
+ " autograph=False)\n",
+ " def accuracy(self, data, labels):\n",
+ " return self.accuracy_fn(self.state_vars, data, labels)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "znJrAVpcxO9u"
+ },
+ "outputs": [],
+ "source": [
+ "# Instantiate the model.\n",
+ "tf_model = TFModel(state, model)\n",
+ "\n",
+ "# Save the model.\n",
+ "tf.saved_model.save(tf_model, \"./\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Y02DHEwTjNzV"
+ },
+ "source": [
+ "## 重新加载模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "i75yS3v2jPpM"
+ },
+ "outputs": [],
+ "source": [
+ "reloaded_model = tf.saved_model.load(\"./\")\n",
+ "\n",
+ "# Test if it works and that the batch size is indeed variable.\n",
+ "x,y = next(iter(train_data.unbatch().batch(13)))\n",
+ "print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))\n",
+ "x,y = next(iter(train_data.unbatch().batch(20)))\n",
+ "print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))\n",
+ "\n",
+ "print(reloaded_model.accuracy(one_batch, one_batch_labels))\n",
+ "print(reloaded_model.accuracy(all_test_data, all_test_labels))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DiwEAwQmlx1x"
+ },
+ "source": [
+ "## 在 TensorFlow 中继续训练转换后的 JAX 模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "MubFcO_jl2vE"
+ },
+ "outputs": [],
+ "source": [
+ "optimizer = tf.keras.optimizers.Adam(learning_rate=tflr_decay)\n",
+ "\n",
+ "# Set the iteration step for the learning rate to resume from where it left off in JAX.\n",
+ "optimizer.iterations.assign(len(eval_losses)*STEPS_PER_EPOCH)\n",
+ "\n",
+ "p = Progress(STEPS_PER_EPOCH)\n",
+ "\n",
+ "for epoch in range(JAX_EPOCHS, JAX_EPOCHS+TF_EPOCHS):\n",
+ "\n",
+ " # This is where the learning rate schedule state is stored in the optimizer state.\n",
+ " optimizer_step = optimizer.iterations\n",
+ "\n",
+ " for step, (data, labels) in enumerate(train_data):\n",
+ " p.step(reset=(step==0))\n",
+ " with tf.GradientTape() as tape:\n",
+ " #loss = reloaded_model.loss(data, labels, True)\n",
+ " loss = reloaded_model.train_loss(data, labels)\n",
+ " grads = tape.gradient(loss, reloaded_model.vars)\n",
+ " optimizer.apply_gradients(zip(grads, reloaded_model.vars))\n",
+ " losses.append(loss)\n",
+ " avg_loss = np.mean(losses[-step:])\n",
+ " avg_losses.append(avg_loss)\n",
+ "\n",
+ " eval_loss = reloaded_model.eval_loss(all_test_data.numpy(), all_test_labels.numpy()).numpy()\n",
+ " eval_losses.append(eval_loss)\n",
+ " eval_accuracy = reloaded_model.accuracy(all_test_data.numpy(), all_test_labels.numpy()).numpy()\n",
+ " eval_accuracies.append(eval_accuracy)\n",
+ "\n",
+ " print(\"\\nEpoch\", epoch, \"train loss:\", avg_loss, \"eval loss:\", eval_loss, \"eval accuracy\", eval_accuracy, \"lr:\", tflr_decay(optimizer.iterations).numpy())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "50V1FSmI6UTk"
+ },
+ "outputs": [],
+ "source": [
+ "display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=2*STEPS_PER_EPOCH)\n",
+ "\n",
+ "# The loss takes a hit when the training restarts, but does not go back to random levels.\n",
+ "# This is likely caused by the optimizer momentum being reinitialized."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "L7lSziW0K0ny"
+ },
+ "source": [
+ "## 后续步骤\n",
+ "\n",
+ "您可以在包含详细指南和示例的文档网站上详细了解 [JAX](https://jax.readthedocs.io/en/latest/index.html) 和 [Flax](https://flax.readthedocs.io/en/latest)。如果您刚接触 JAX,请务必浏览 [JAX 101 教程](https://jax.readthedocs.io/en/latest/jax-101/index.html),并查看 [Flax 快速入门](https://flax.readthedocs.io/en/latest/getting_started.html)。要详细了解如何将 JAX 模型转换为 TensorFlow 格式,请查看 GitHub 上的 [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf) 实用工具。如果您有兴趣将 JAX 模型转换为使用 TensorFlow.js 在浏览器中运行,请访问[使用 TensorFlow.js 在 Web 上运行 JAX](https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html)。如果您想准备 JAX 模型以在 TensorFLow Lite 中运行,请访问 [TFLite 的 JAX 模型转换](https://tensorflow.google.cn/lite/examples/jax_conversion/overview)指南。"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "jax2tf.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/site/zh-cn/guide/keras/preprocessing_layers.ipynb b/site/zh-cn/guide/keras/preprocessing_layers.ipynb
new file mode 100644
index 0000000000..467cfb6822
--- /dev/null
+++ b/site/zh-cn/guide/keras/preprocessing_layers.ipynb
@@ -0,0 +1,751 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b518b04cbfe0"
+ },
+ "source": [
+ "##### Copyright 2020 The TensorFlow Authors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "906e07f6e562"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6e083398b477"
+ },
+ "source": [
+ "# 使用预处理层"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "64010bd23c2e"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b1d403f04693"
+ },
+ "source": [
+ "## Keras 预处理\n",
+ "\n",
+ "Keras 预处理层 API 可供开发者构建 Keras 原生输入处理流水线。这些输入处理流水线可在非 Keras 工作流中用作独立预处理代码,直接与 Keras 模型结合,并作为 Keras SavedModel 的一部分导出。\n",
+ "\n",
+ "借助 Keras 预处理层,您可以构建和导出真正端到端的模型:接受原始图像或原始结构化数据作为输入的模型;自行处理特征归一化或特征值索引的模型。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "313360fa9024"
+ },
+ "source": [
+ "## 可用预处理\n",
+ "\n",
+ "### 文本预处理\n",
+ "\n",
+ "- `tf.keras.layers.TextVectorization`:将原始字符串转换为可由 `Embedding` 层或 `Dense` 层读取的编码表示法。\n",
+ "\n",
+ "### 数值特征预处理\n",
+ "\n",
+ "- `tf.keras.layers.Normalization`:对输入特征执行逐特征归一化。\n",
+ "- `tf.keras.layers.Discretization`:将连续数值特征转换为整数分类特征。\n",
+ "\n",
+ "### 分类特征预处理\n",
+ "\n",
+ "- `tf.keras.layers.CategoryEncoding`:将整数分类特征转换为独热、多热或计数密集表示法。\n",
+ "- `tf.keras.layers.Hashing`:执行分类特征哈希,也称为“哈希技巧”(hashing trick)。\n",
+ "- `tf.keras.layers.StringLookup`:将字符串分类值转换为可由 `Embedding` 层或 `Dense` 层读取的编码表示法。\n",
+ "- `tf.keras.layers.IntegerLookup`:将整数分类值转换为可由 `Embedding` 层或 `Dense` 层读取的编码表示法。\n",
+ "\n",
+ "### 图像预处理\n",
+ "\n",
+ "这些层用于标准化图像模型的输入。\n",
+ "\n",
+ "- `tf.keras.layers.Resizing`:将一批图像的大小调整为目标大小。\n",
+ "- `tf.keras.layers.Rescaling`:重新缩放和偏移一批图像的值(例如,从 `[0, 255]` 范围内的输入变为 `[0, 1]` 范围内的输入)。\n",
+ "- `tf.keras.layers.CenterCrop`:返回一批图像的中心裁剪。\n",
+ "\n",
+ "### 图像数据增强\n",
+ "\n",
+ "以下层会对一批图像应用随机增强转换。它们仅在训练期间有效。\n",
+ "\n",
+ "- `tf.keras.layers.RandomCrop`\n",
+ "- `tf.keras.layers.RandomFlip`\n",
+ "- `tf.keras.layers.RandomTranslation`\n",
+ "- `tf.keras.layers.RandomRotation`\n",
+ "- `tf.keras.layers.RandomZoom`\n",
+ "- `tf.keras.layers.RandomHeight`\n",
+ "- `tf.keras.layers.RandomWidth`\n",
+ "- `tf.keras.layers.RandomContrast`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "c923e41fb1b4"
+ },
+ "source": [
+ "## `adapt()` 方法\n",
+ "\n",
+ "一些预处理层具有可基于训练数据样本计算的内部状态。以下是有状态预处理层的列表:\n",
+ "\n",
+ "- `TextVectorization`:保存字符串词例和整数索引之间的映射。\n",
+ "- `StringLookup` 和 `IntegerLookup`:保存输入值和整数索引之间的映射。\n",
+ "- `Normalization`:保存特征的平均值和标准差。\n",
+ "- `Discretization`:保存值桶边界相关信息。\n",
+ "\n",
+ "至关重要的是,这些层**不可训练**。它们的状态不是在训练期间设定的;必须在**训练之前**设置状态,方法是通过预先计算的常量对其进行初始化,或者基于数据对其进行“调整”。\n",
+ "\n",
+ "您可以通过如下 `adapt()` 方法将预处理层公开给训练数据以设置其状态:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4cac6bd80812"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import tensorflow as tf\n",
+ "from tensorflow.keras import layers\n",
+ "\n",
+ "data = np.array([[0.1, 0.2, 0.3], [0.8, 0.9, 1.0], [1.5, 1.6, 1.7],])\n",
+ "layer = layers.Normalization()\n",
+ "layer.adapt(data)\n",
+ "normalized_data = layer(data)\n",
+ "\n",
+ "print(\"Features mean: %.2f\" % (normalized_data.numpy().mean()))\n",
+ "print(\"Features std: %.2f\" % (normalized_data.numpy().std()))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "d43b8246b8a3"
+ },
+ "source": [
+ "`adapt()` 方法接受 Numpy 数组或 `tf.data.Dataset` 对象。对于 `StringLookup` 和 `TextVectorization`,您还可以传递字符串列表:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "48d95713348a"
+ },
+ "outputs": [],
+ "source": [
+ "data = [\n",
+ " \"ξεῖν᾽, ἦ τοι μὲν ὄνειροι ἀμήχανοι ἀκριτόμυθοι\",\n",
+ " \"γίγνοντ᾽, οὐδέ τι πάντα τελείεται ἀνθρώποισι.\",\n",
+ " \"δοιαὶ γάρ τε πύλαι ἀμενηνῶν εἰσὶν ὀνείρων:\",\n",
+ " \"αἱ μὲν γὰρ κεράεσσι τετεύχαται, αἱ δ᾽ ἐλέφαντι:\",\n",
+ " \"τῶν οἳ μέν κ᾽ ἔλθωσι διὰ πριστοῦ ἐλέφαντος,\",\n",
+ " \"οἵ ῥ᾽ ἐλεφαίρονται, ἔπε᾽ ἀκράαντα φέροντες:\",\n",
+ " \"οἱ δὲ διὰ ξεστῶν κεράων ἔλθωσι θύραζε,\",\n",
+ " \"οἵ ῥ᾽ ἔτυμα κραίνουσι, βροτῶν ὅτε κέν τις ἴδηται.\",\n",
+ "]\n",
+ "layer = layers.TextVectorization()\n",
+ "layer.adapt(data)\n",
+ "vectorized_text = layer(data)\n",
+ "print(vectorized_text)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7619914dfb40"
+ },
+ "source": [
+ "此外,自适应层始终公开一个可以通过构造函数参数或权重赋值直接设置状态的选项。如果预期的状态值在构造层时已知,或者是在 `adapt()` 调用之外计算的,则可以在不依赖层的内部计算的情况下对其进行设置。例如,如果 `TextVectorization`、`StringLookup` 或 `IntegerLookup` 层的外部词汇文件已存在,则可以通过在层的构造函数参数中传递词汇文件的路径来将这些文件直接加载到查找表中。\n",
+ "\n",
+ "下面是我们使用预先计算的词汇实例化 `StringLookup` 层的示例:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9df56efc7f3b"
+ },
+ "outputs": [],
+ "source": [
+ "vocab = [\"a\", \"b\", \"c\", \"d\"]\n",
+ "data = tf.constant([[\"a\", \"c\", \"d\"], [\"d\", \"z\", \"b\"]])\n",
+ "layer = layers.StringLookup(vocabulary=vocab)\n",
+ "vectorized_data = layer(data)\n",
+ "print(vectorized_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "49cbfe135b00"
+ },
+ "source": [
+ "## 在模型之前或模型内部预处理数据\n",
+ "\n",
+ "您可以通过以下两种方式使用预处理层:\n",
+ "\n",
+ "**选项 1**:使它们成为模型的一部分,如下所示:\n",
+ "\n",
+ "```python\n",
+ "inputs = keras.Input(shape=input_shape)\n",
+ "x = preprocessing_layer(inputs)\n",
+ "outputs = rest_of_the_model(x)\n",
+ "model = keras.Model(inputs, outputs)\n",
+ "```\n",
+ "\n",
+ "使用此选项,预处理将在设备上与模型执行的其余部分同步进行,这意味着它将受益于 GPU 加速。如果您在 GPU 上进行训练,那么这是 `Normalization` 层以及所有图像预处理和数据增强层的最佳选择。\n",
+ "\n",
+ "**选项 2**:将它应用到您的 `tf.data.Dataset`,以获得可生成批量预处理数据的数据集,如下所示:\n",
+ "\n",
+ "```python\n",
+ "dataset = dataset.map(lambda x, y: (preprocessing_layer(x), y))\n",
+ "```\n",
+ "\n",
+ "使用此选项,您的预处理将在 CPU 上异步进行,并在进入模型之前进行缓存。此外,如果您对数据集调用 `dataset.prefetch(tf.data.AUTOTUNE)`,则预处理将与训练同时有效进行:\n",
+ "\n",
+ "```python\n",
+ "dataset = dataset.map(lambda x, y: (preprocessing_layer(x), y))\n",
+ "dataset = dataset.prefetch(tf.data.AUTOTUNE)\n",
+ "model.fit(dataset, ...)\n",
+ "```\n",
+ "\n",
+ "这是 `TextVectorization` 和所有结构化数据预处理层的最佳选择。如果您在 CPU 上进行训练并且使用图像预处理层,那么这同样是一个不错的选择。\n",
+ "\n",
+ "**在 TPU 上运行时,应始终将预处理层置于 `tf.data` 流水线中**(`Normalization` 和 `Rescaling` 除外,由于第一层是图像模型,它们在 TPU 上运行良好且被普遍使用)。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "32f6d2a104b7"
+ },
+ "source": [
+ "## 推断时在模型内部进行预处理的好处\n",
+ "\n",
+ "即使您采用选项 2,您稍后也可能需要导出包含预处理层的仅推断端到端模型。这样做的主要好处是**它使您的模型具有可移植性**,并且**有助于降低[训练/应用偏差](https://developers.google.com/machine-learning/guides/rules-of-ml#training-serving_skew)**。\n",
+ "\n",
+ "当所有数据预处理均为模型的一部分时,其他人可以加载和使用您的模型,而无需了解每个特征预计会如何编码和归一化。您的推断模型将能够处理原始图像或原始结构化数据,并且不需要模型的用户了解诸如以下详细信息: 用于文本的词例化方案、用于分类特征的索引方案、图像像素值是归一化为 `[-1, +1]` 还是 `[0, 1]` 等。如果您要将模型导出到其他运行时(例如 TensorFlow.js),那么这尤为强大:您不必在 JavaScript 中重新实现预处理流水线。\n",
+ "\n",
+ "如果您最初将预处理层置于 `tf.data` 流水线内,则可以导出打包有预处理的推断模型。只需实例化一个链接着预处理层和训练模型的新模型即可:\n",
+ "\n",
+ "```python\n",
+ "inputs = keras.Input(shape=input_shape)\n",
+ "x = preprocessing_layer(inputs)\n",
+ "outputs = training_model(x)\n",
+ "inference_model = keras.Model(inputs, outputs)\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b41b381d48d4"
+ },
+ "source": [
+ "## 快速秘诀\n",
+ "\n",
+ "### 图像数据增强\n",
+ "\n",
+ "请注意,图像数据增强层仅在训练期间有效(类似于 `Dropout` 层)。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "a3793692e983"
+ },
+ "outputs": [],
+ "source": [
+ "from tensorflow import keras\n",
+ "from tensorflow.keras import layers\n",
+ "\n",
+ "# Create a data augmentation stage with horizontal flipping, rotations, zooms\n",
+ "data_augmentation = keras.Sequential(\n",
+ " [\n",
+ " layers.RandomFlip(\"horizontal\"),\n",
+ " layers.RandomRotation(0.1),\n",
+ " layers.RandomZoom(0.1),\n",
+ " ]\n",
+ ")\n",
+ "\n",
+ "# Load some data\n",
+ "(x_train, y_train), _ = keras.datasets.cifar10.load_data()\n",
+ "input_shape = x_train.shape[1:]\n",
+ "classes = 10\n",
+ "\n",
+ "# Create a tf.data pipeline of augmented images (and their labels)\n",
+ "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
+ "train_dataset = train_dataset.batch(16).map(lambda x, y: (data_augmentation(x), y))\n",
+ "\n",
+ "\n",
+ "# Create a model and train it on the augmented image data\n",
+ "inputs = keras.Input(shape=input_shape)\n",
+ "x = layers.Rescaling(1.0 / 255)(inputs) # Rescale inputs\n",
+ "outputs = keras.applications.ResNet50( # Add the rest of the model\n",
+ " weights=None, input_shape=input_shape, classes=classes\n",
+ ")(x)\n",
+ "model = keras.Model(inputs, outputs)\n",
+ "model.compile(optimizer=\"rmsprop\", loss=\"sparse_categorical_crossentropy\")\n",
+ "model.fit(train_dataset, steps_per_epoch=5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "51d369f0310f"
+ },
+ "source": [
+ "您可以在[从零开始进行图像分类](https://keras.io/examples/vision/image_classification_from_scratch/)示例中查看类似设置。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a79a1c48b2b7"
+ },
+ "source": [
+ "### 归一化数值特征"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9cc2607a45c8"
+ },
+ "outputs": [],
+ "source": [
+ "# Load some data\n",
+ "(x_train, y_train), _ = keras.datasets.cifar10.load_data()\n",
+ "x_train = x_train.reshape((len(x_train), -1))\n",
+ "input_shape = x_train.shape[1:]\n",
+ "classes = 10\n",
+ "\n",
+ "# Create a Normalization layer and set its internal state using the training data\n",
+ "normalizer = layers.Normalization()\n",
+ "normalizer.adapt(x_train)\n",
+ "\n",
+ "# Create a model that include the normalization layer\n",
+ "inputs = keras.Input(shape=input_shape)\n",
+ "x = normalizer(inputs)\n",
+ "outputs = layers.Dense(classes, activation=\"softmax\")(x)\n",
+ "model = keras.Model(inputs, outputs)\n",
+ "\n",
+ "# Train the model\n",
+ "model.compile(optimizer=\"adam\", loss=\"sparse_categorical_crossentropy\")\n",
+ "model.fit(x_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "62685d477010"
+ },
+ "source": [
+ "### 通过独热编码进行字符串分类特征编码"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ae0d2b0405f1"
+ },
+ "outputs": [],
+ "source": [
+ "# Define some toy data\n",
+ "data = tf.constant([[\"a\"], [\"b\"], [\"c\"], [\"b\"], [\"c\"], [\"a\"]])\n",
+ "\n",
+ "# Use StringLookup to build an index of the feature values and encode output.\n",
+ "lookup = layers.StringLookup(output_mode=\"one_hot\")\n",
+ "lookup.adapt(data)\n",
+ "\n",
+ "# Convert new test data (which includes unknown feature values)\n",
+ "test_data = tf.constant([[\"a\"], [\"b\"], [\"c\"], [\"d\"], [\"e\"], [\"\"]])\n",
+ "encoded_data = lookup(test_data)\n",
+ "print(encoded_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "686aeda532f5"
+ },
+ "source": [
+ "请注意,此处的索引 0 为词汇之外的值(`adapt()` 期间未出现的值)保留。\n",
+ "\n",
+ "您可以在[从零开始进行结构化数据分类](https://keras.io/examples/structured_data/structured_data_classification_from_scratch/)示例中查看 `StringLookup` 的实际应用。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dc8af3e290df"
+ },
+ "source": [
+ "### 通过独热编码进行整数分类特征编码"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "75f3d6af4522"
+ },
+ "outputs": [],
+ "source": [
+ "# Define some toy data\n",
+ "data = tf.constant([[10], [20], [20], [10], [30], [0]])\n",
+ "\n",
+ "# Use IntegerLookup to build an index of the feature values and encode output.\n",
+ "lookup = layers.IntegerLookup(output_mode=\"one_hot\")\n",
+ "lookup.adapt(data)\n",
+ "\n",
+ "# Convert new test data (which includes unknown feature values)\n",
+ "test_data = tf.constant([[10], [10], [20], [50], [60], [0]])\n",
+ "encoded_data = lookup(test_data)\n",
+ "print(encoded_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "da5a6be487be"
+ },
+ "source": [
+ "请注意,索引 0 为缺失值(您应将其指定为值 0)保留,索引 1 为词汇之外的值(`adapt()` 期间未出现的值)保留。您可以使用 `IntegerLookup` 的 `mask_token` 和 `oov_token` 构造函数参数进行配置。\n",
+ "\n",
+ "您可以在[从零开始进行结构化数据分类](https://keras.io/examples/structured_data/structured_data_classification_from_scratch/)示例中看到 `IntegerLookup` 的实际应用。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8fbfaa6ab3e2"
+ },
+ "source": [
+ "### 对整数分类特征应用哈希技巧\n",
+ "\n",
+ "如果您拥有可以接受许多不同值(约 10e3 或更高次方)的分类特征,其中每个值仅在数据中出现几次,那么对特征值进行索引和独热编码就变得不切实际且低效。相反,应用“哈希技巧”可能是一个好主意:将值散列到固定大小的向量。这使得特征空间的大小易于管理,并且无需显式索引。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8f6c1f84c43c"
+ },
+ "outputs": [],
+ "source": [
+ "# Sample data: 10,000 random integers with values between 0 and 100,000\n",
+ "data = np.random.randint(0, 100000, size=(10000, 1))\n",
+ "\n",
+ "# Use the Hashing layer to hash the values to the range [0, 64]\n",
+ "hasher = layers.Hashing(num_bins=64, salt=1337)\n",
+ "\n",
+ "# Use the CategoryEncoding layer to multi-hot encode the hashed values\n",
+ "encoder = layers.CategoryEncoding(num_tokens=64, output_mode=\"multi_hot\")\n",
+ "encoded_data = encoder(hasher(data))\n",
+ "print(encoded_data.shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "df69b434d327"
+ },
+ "source": [
+ "### 将文本编码为词例索引序列\n",
+ "\n",
+ "这是预处理要传递到 `Embedding` 层的文本时应采用的方式。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "361b561bc88b"
+ },
+ "outputs": [],
+ "source": [
+ "# Define some text data to adapt the layer\n",
+ "adapt_data = tf.constant(\n",
+ " [\n",
+ " \"The Brain is wider than the Sky\",\n",
+ " \"For put them side by side\",\n",
+ " \"The one the other will contain\",\n",
+ " \"With ease and You beside\",\n",
+ " ]\n",
+ ")\n",
+ "\n",
+ "# Create a TextVectorization layer\n",
+ "text_vectorizer = layers.TextVectorization(output_mode=\"int\")\n",
+ "# Index the vocabulary via `adapt()`\n",
+ "text_vectorizer.adapt(adapt_data)\n",
+ "\n",
+ "# Try out the layer\n",
+ "print(\n",
+ " \"Encoded text:\\n\", text_vectorizer([\"The Brain is deeper than the sea\"]).numpy(),\n",
+ ")\n",
+ "\n",
+ "# Create a simple model\n",
+ "inputs = keras.Input(shape=(None,), dtype=\"int64\")\n",
+ "x = layers.Embedding(input_dim=text_vectorizer.vocabulary_size(), output_dim=16)(inputs)\n",
+ "x = layers.GRU(8)(x)\n",
+ "outputs = layers.Dense(1)(x)\n",
+ "model = keras.Model(inputs, outputs)\n",
+ "\n",
+ "# Create a labeled dataset (which includes unknown tokens)\n",
+ "train_dataset = tf.data.Dataset.from_tensor_slices(\n",
+ " ([\"The Brain is deeper than the sea\", \"for if they are held Blue to Blue\"], [1, 0])\n",
+ ")\n",
+ "\n",
+ "# Preprocess the string inputs, turning them into int sequences\n",
+ "train_dataset = train_dataset.batch(2).map(lambda x, y: (text_vectorizer(x), y))\n",
+ "# Train the model on the int sequences\n",
+ "print(\"\\nTraining model...\")\n",
+ "model.compile(optimizer=\"rmsprop\", loss=\"mse\")\n",
+ "model.fit(train_dataset)\n",
+ "\n",
+ "# For inference, you can export a model that accepts strings as input\n",
+ "inputs = keras.Input(shape=(1,), dtype=\"string\")\n",
+ "x = text_vectorizer(inputs)\n",
+ "outputs = model(x)\n",
+ "end_to_end_model = keras.Model(inputs, outputs)\n",
+ "\n",
+ "# Call the end-to-end model on test data (which includes unknown tokens)\n",
+ "print(\"\\nCalling end-to-end model on test string...\")\n",
+ "test_data = tf.constant([\"The one the other will absorb\"])\n",
+ "test_output = end_to_end_model(test_data)\n",
+ "print(\"Model output:\", test_output)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "e725dbcae3e4"
+ },
+ "source": [
+ "您可以在示例从头开始进行文本分类中查看 `TextVectorization` 层与 Embedding
模式组合的实际使用情况。\n",
+ "\n",
+ "请注意,在训练此类模型时,为了获得最佳性能,您应始终使用 `TextVectorization` 层作为输入流水线的一部分。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "28c2f2ff61fb"
+ },
+ "source": [
+ "### 通过多热编码将文本编码为 ngram 的密集矩阵\n",
+ "\n",
+ "这是预处理要传递到 `Dense` 层的文本时应采用的方式。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7bae1c223cd8"
+ },
+ "outputs": [],
+ "source": [
+ "# Define some text data to adapt the layer\n",
+ "adapt_data = tf.constant(\n",
+ " [\n",
+ " \"The Brain is wider than the Sky\",\n",
+ " \"For put them side by side\",\n",
+ " \"The one the other will contain\",\n",
+ " \"With ease and You beside\",\n",
+ " ]\n",
+ ")\n",
+ "# Instantiate TextVectorization with \"multi_hot\" output_mode\n",
+ "# and ngrams=2 (index all bigrams)\n",
+ "text_vectorizer = layers.TextVectorization(output_mode=\"multi_hot\", ngrams=2)\n",
+ "# Index the bigrams via `adapt()`\n",
+ "text_vectorizer.adapt(adapt_data)\n",
+ "\n",
+ "# Try out the layer\n",
+ "print(\n",
+ " \"Encoded text:\\n\", text_vectorizer([\"The Brain is deeper than the sea\"]).numpy(),\n",
+ ")\n",
+ "\n",
+ "# Create a simple model\n",
+ "inputs = keras.Input(shape=(text_vectorizer.vocabulary_size(),))\n",
+ "outputs = layers.Dense(1)(inputs)\n",
+ "model = keras.Model(inputs, outputs)\n",
+ "\n",
+ "# Create a labeled dataset (which includes unknown tokens)\n",
+ "train_dataset = tf.data.Dataset.from_tensor_slices(\n",
+ " ([\"The Brain is deeper than the sea\", \"for if they are held Blue to Blue\"], [1, 0])\n",
+ ")\n",
+ "\n",
+ "# Preprocess the string inputs, turning them into int sequences\n",
+ "train_dataset = train_dataset.batch(2).map(lambda x, y: (text_vectorizer(x), y))\n",
+ "# Train the model on the int sequences\n",
+ "print(\"\\nTraining model...\")\n",
+ "model.compile(optimizer=\"rmsprop\", loss=\"mse\")\n",
+ "model.fit(train_dataset)\n",
+ "\n",
+ "# For inference, you can export a model that accepts strings as input\n",
+ "inputs = keras.Input(shape=(1,), dtype=\"string\")\n",
+ "x = text_vectorizer(inputs)\n",
+ "outputs = model(x)\n",
+ "end_to_end_model = keras.Model(inputs, outputs)\n",
+ "\n",
+ "# Call the end-to-end model on test data (which includes unknown tokens)\n",
+ "print(\"\\nCalling end-to-end model on test string...\")\n",
+ "test_data = tf.constant([\"The one the other will absorb\"])\n",
+ "test_output = end_to_end_model(test_data)\n",
+ "print(\"Model output:\", test_output)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "336a4d3426ed"
+ },
+ "source": [
+ "### 通过 TF-IDF 加权将文本编码为 ngram 的密集矩阵\n",
+ "\n",
+ "这是在将文本传递到 `Dense` 层之前对其进行预处理的另一种方式。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5b6c0fec928e"
+ },
+ "outputs": [],
+ "source": [
+ "# Define some text data to adapt the layer\n",
+ "adapt_data = tf.constant(\n",
+ " [\n",
+ " \"The Brain is wider than the Sky\",\n",
+ " \"For put them side by side\",\n",
+ " \"The one the other will contain\",\n",
+ " \"With ease and You beside\",\n",
+ " ]\n",
+ ")\n",
+ "# Instantiate TextVectorization with \"tf-idf\" output_mode\n",
+ "# (multi-hot with TF-IDF weighting) and ngrams=2 (index all bigrams)\n",
+ "text_vectorizer = layers.TextVectorization(output_mode=\"tf-idf\", ngrams=2)\n",
+ "# Index the bigrams and learn the TF-IDF weights via `adapt()`\n",
+ "\n",
+ "with tf.device(\"CPU\"):\n",
+ " # A bug that prevents this from running on GPU for now.\n",
+ " text_vectorizer.adapt(adapt_data)\n",
+ "\n",
+ "# Try out the layer\n",
+ "print(\n",
+ " \"Encoded text:\\n\", text_vectorizer([\"The Brain is deeper than the sea\"]).numpy(),\n",
+ ")\n",
+ "\n",
+ "# Create a simple model\n",
+ "inputs = keras.Input(shape=(text_vectorizer.vocabulary_size(),))\n",
+ "outputs = layers.Dense(1)(inputs)\n",
+ "model = keras.Model(inputs, outputs)\n",
+ "\n",
+ "# Create a labeled dataset (which includes unknown tokens)\n",
+ "train_dataset = tf.data.Dataset.from_tensor_slices(\n",
+ " ([\"The Brain is deeper than the sea\", \"for if they are held Blue to Blue\"], [1, 0])\n",
+ ")\n",
+ "\n",
+ "# Preprocess the string inputs, turning them into int sequences\n",
+ "train_dataset = train_dataset.batch(2).map(lambda x, y: (text_vectorizer(x), y))\n",
+ "# Train the model on the int sequences\n",
+ "print(\"\\nTraining model...\")\n",
+ "model.compile(optimizer=\"rmsprop\", loss=\"mse\")\n",
+ "model.fit(train_dataset)\n",
+ "\n",
+ "# For inference, you can export a model that accepts strings as input\n",
+ "inputs = keras.Input(shape=(1,), dtype=\"string\")\n",
+ "x = text_vectorizer(inputs)\n",
+ "outputs = model(x)\n",
+ "end_to_end_model = keras.Model(inputs, outputs)\n",
+ "\n",
+ "# Call the end-to-end model on test data (which includes unknown tokens)\n",
+ "print(\"\\nCalling end-to-end model on test string...\")\n",
+ "test_data = tf.constant([\"The one the other will absorb\"])\n",
+ "test_output = end_to_end_model(test_data)\n",
+ "print(\"Model output:\", test_output)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "143ce01c5558"
+ },
+ "source": [
+ "## 重要问题\n",
+ "\n",
+ "### 处理包含非常大的词汇的查找层\n",
+ "\n",
+ "您可能会在 `TextVectorization`、`StringLookup` 层或 `IntegerLookup` 层中处理非常大的词汇。通常,大于 500MB 的词汇就会被视为“非常大”。\n",
+ "\n",
+ "在这种情况下,为了获得最佳性能,您应避免使用 `adapt()`。相反,应提前预先计算您的词汇(可使用 Apache Beam 或 TF Transform 来实现)并将其存储在文件中。然后,在构建时将文件路径作为 `vocabulary` 参数传递,以将词汇加载到层中。\n",
+ "\n",
+ "### 在 TPU pod 上或与 `ParameterServerStrategy` 一起使用查找层。\n",
+ "\n",
+ "有一个未解决的问题,它会导致在 TPU pod 上或通过 `ParameterServerStrategy` 在多台计算机上进行训练时,使用 `TextVectorization`、`StringLookup` 或 `IntegerLookup` 层时出现性能下降。该问题预计将在 TensorFlow 2.7 中得到修正。"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "preprocessing_layers.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/site/zh-cn/guide/keras/sequential_model.ipynb b/site/zh-cn/guide/keras/sequential_model.ipynb
new file mode 100644
index 0000000000..7690564874
--- /dev/null
+++ b/site/zh-cn/guide/keras/sequential_model.ipynb
@@ -0,0 +1,686 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b518b04cbfe0"
+ },
+ "source": [
+ "##### Copyright 2020 The TensorFlow Authors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "906e07f6e562"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3e19705694d6"
+ },
+ "source": [
+ "# 序贯模型"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "defbb10e8ae3"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8d4ac441b1fc"
+ },
+ "source": [
+ "## 安装"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0472bf67b2bf"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "from tensorflow import keras\n",
+ "from tensorflow.keras import layers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "80f7713c6b92"
+ },
+ "source": [
+ "## 何时使用序贯模型\n",
+ "\n",
+ "`Sequential` 模型适用于**简单的层堆栈**,其中每个层都**恰好有一个输入张量和一个输出张量**。\n",
+ "\n",
+ "以下 `Sequential` 模型(仅作为示意):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "f536515be229"
+ },
+ "outputs": [],
+ "source": [
+ "# Define Sequential model with 3 layers\n",
+ "model = keras.Sequential(\n",
+ " [\n",
+ " layers.Dense(2, activation=\"relu\", name=\"layer1\"),\n",
+ " layers.Dense(3, activation=\"relu\", name=\"layer2\"),\n",
+ " layers.Dense(4, name=\"layer3\"),\n",
+ " ]\n",
+ ")\n",
+ "# Call model on a test input\n",
+ "x = tf.ones((3, 3))\n",
+ "y = model(x)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7d81502d9753"
+ },
+ "source": [
+ "等效于此函数:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7059a8890f72"
+ },
+ "outputs": [],
+ "source": [
+ "# Create 3 layers\n",
+ "layer1 = layers.Dense(2, activation=\"relu\", name=\"layer1\")\n",
+ "layer2 = layers.Dense(3, activation=\"relu\", name=\"layer2\")\n",
+ "layer3 = layers.Dense(4, name=\"layer3\")\n",
+ "\n",
+ "# Call layers on a test input\n",
+ "x = tf.ones((3, 3))\n",
+ "y = layer3(layer2(layer1(x)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cdf6d2932c31"
+ },
+ "source": [
+ "在以下情况下,序贯模型**不适用**:\n",
+ "\n",
+ "- 您的模型有多个输入或多个输出\n",
+ "- 您的任何层都有多个输入或多个输出\n",
+ "- 您需要进行层共享\n",
+ "- 您需要非线性拓扑(例如残差连接、多分支模型)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "c5706d9f8eb8"
+ },
+ "source": [
+ "## 创建序贯模型\n",
+ "\n",
+ "您可以通过将层列表传递给序贯构造函数来创建序贯模型:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8b3eee00f80d"
+ },
+ "outputs": [],
+ "source": [
+ "model = keras.Sequential(\n",
+ " [\n",
+ " layers.Dense(2, activation=\"relu\"),\n",
+ " layers.Dense(3, activation=\"relu\"),\n",
+ " layers.Dense(4),\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1939a7a4a66c"
+ },
+ "source": [
+ "它的层可以通过 `layers` 属性访问:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "49c0448b6da2"
+ },
+ "outputs": [],
+ "source": [
+ "model.layers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b4c7957e9913"
+ },
+ "source": [
+ "您还可以通过 `add()` 方法增量式创建序贯模型:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "d54fde401054"
+ },
+ "outputs": [],
+ "source": [
+ "model = keras.Sequential()\n",
+ "model.add(layers.Dense(2, activation=\"relu\"))\n",
+ "model.add(layers.Dense(3, activation=\"relu\"))\n",
+ "model.add(layers.Dense(4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "d16278f5a1dc"
+ },
+ "source": [
+ "请注意,还有一个相应的 `pop()` 方法来移除层:序贯模型的行为非常类似于层列表。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "e89f35b73979"
+ },
+ "outputs": [],
+ "source": [
+ "model.pop()\n",
+ "print(len(model.layers)) # 2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "99cb1c9a7c7a"
+ },
+ "source": [
+ "另请注意,序贯构造函数接受 `name` 参数,就像 Keras 中的任何层或模型一样。这对于使用语义上有意义的名称来注释 TensorBoard 计算图非常有用。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "068c2f82e7cb"
+ },
+ "outputs": [],
+ "source": [
+ "model = keras.Sequential(name=\"my_sequential\")\n",
+ "model.add(layers.Dense(2, activation=\"relu\", name=\"layer1\"))\n",
+ "model.add(layers.Dense(3, activation=\"relu\", name=\"layer2\"))\n",
+ "model.add(layers.Dense(4, name=\"layer3\"))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a6247ff17d3a"
+ },
+ "source": [
+ "## 预先指定输入形状\n",
+ "\n",
+ "一般来说,Keras 中的所有层都需要知道其输入的形状,以便能够创建其权重。因此,当您创建这样的层时,它最初没有权重:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "373ecbb5c6bd"
+ },
+ "outputs": [],
+ "source": [
+ "layer = layers.Dense(3)\n",
+ "layer.weights # Empty"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "da150335961e"
+ },
+ "source": [
+ "当第一次在输入上被调用时,它会创建其权重,因为权重的形状取决于输入的形状:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "bf28829ce193"
+ },
+ "outputs": [],
+ "source": [
+ "# Call layer on a test input\n",
+ "x = tf.ones((1, 4))\n",
+ "y = layer(x)\n",
+ "layer.weights # Now it has weights, of shape (4, 3) and (3,)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "e50f21c5f247"
+ },
+ "source": [
+ "当然,这也适用于序贯模型。当您实例化没有输入形状的序贯模型时,它不会被“构建”:它没有权重(并且调用 `model.weights` 会导致说明这一点的错误)。当模型第一次看到一些输入数据时,会创建权重:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "04479960526c"
+ },
+ "outputs": [],
+ "source": [
+ "model = keras.Sequential(\n",
+ " [\n",
+ " layers.Dense(2, activation=\"relu\"),\n",
+ " layers.Dense(3, activation=\"relu\"),\n",
+ " layers.Dense(4),\n",
+ " ]\n",
+ ") # No weights at this stage!\n",
+ "\n",
+ "# At this point, you can't do this:\n",
+ "# model.weights\n",
+ "\n",
+ "# You also can't do this:\n",
+ "# model.summary()\n",
+ "\n",
+ "# Call the model on a test input\n",
+ "x = tf.ones((1, 4))\n",
+ "y = model(x)\n",
+ "print(\"Number of weights after calling the model:\", len(model.weights)) # 6"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2837e3d2c798"
+ },
+ "source": [
+ "一旦模型“已构建”,您就可以调用它的 `summary()` 方法来显示其内容:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1368bd27f679"
+ },
+ "outputs": [],
+ "source": [
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "08cf1da27edb"
+ },
+ "source": [
+ "不过,在增量式构建序贯模型时,它非常有用,能够显示迄今为止模型的摘要,包括当前的输出形状。在这种情况下,您应通过将 `Input` 对象传递给您的模型来启动模型,以便模型从一开始就知道其输入形状:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "e3d2024cfeeb"
+ },
+ "outputs": [],
+ "source": [
+ "model = keras.Sequential()\n",
+ "model.add(keras.Input(shape=(4,)))\n",
+ "model.add(layers.Dense(2, activation=\"relu\"))\n",
+ "\n",
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3d965e3761a8"
+ },
+ "source": [
+ "请注意,`Input` 对象不会显示为 `model.layers` 的一部分,因为它不是层:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8e3b0d58e7ed"
+ },
+ "outputs": [],
+ "source": [
+ "model.layers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8a057b1baf72"
+ },
+ "source": [
+ "一种简单的替代方式是将 `input_shape` 参数传递给第一层:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1c6ab83d68ea"
+ },
+ "outputs": [],
+ "source": [
+ "model = keras.Sequential()\n",
+ "model.add(layers.Dense(2, activation=\"relu\", input_shape=(4,)))\n",
+ "\n",
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "40c14619d283"
+ },
+ "source": [
+ "使用像这样的预定义输入形状构建的模型始终具有权重(甚至在看到任何数据之前),并且始终具有定义的输出形状。\n",
+ "\n",
+ "一般来说,如果您知道序贯模型的输入形状是什么,推荐的最佳做法是始终提前指定它。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "843f6b6505b3"
+ },
+ "source": [
+ "## 常见的调试工作流:`add()` + `summary()`\n",
+ "\n",
+ "在构建新的序贯架构时,使用 `add()` 增量式堆叠层并经常打印模型摘要非常有用。例如,这样便能监控 `Conv2D` 和 `MaxPooling2D` 层的堆栈如何对图像特征映射进行下采样:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "46bfb8f7dc6e"
+ },
+ "outputs": [],
+ "source": [
+ "model = keras.Sequential()\n",
+ "model.add(keras.Input(shape=(250, 250, 3))) # 250x250 RGB images\n",
+ "model.add(layers.Conv2D(32, 5, strides=2, activation=\"relu\"))\n",
+ "model.add(layers.Conv2D(32, 3, activation=\"relu\"))\n",
+ "model.add(layers.MaxPooling2D(3))\n",
+ "\n",
+ "# Can you guess what the current output shape is at this point? Probably not.\n",
+ "# Let's just print it:\n",
+ "model.summary()\n",
+ "\n",
+ "# The answer was: (40, 40, 32), so we can keep downsampling...\n",
+ "\n",
+ "model.add(layers.Conv2D(32, 3, activation=\"relu\"))\n",
+ "model.add(layers.Conv2D(32, 3, activation=\"relu\"))\n",
+ "model.add(layers.MaxPooling2D(3))\n",
+ "model.add(layers.Conv2D(32, 3, activation=\"relu\"))\n",
+ "model.add(layers.Conv2D(32, 3, activation=\"relu\"))\n",
+ "model.add(layers.MaxPooling2D(2))\n",
+ "\n",
+ "# And now?\n",
+ "model.summary()\n",
+ "\n",
+ "# Now that we have 4x4 feature maps, time to apply global max pooling.\n",
+ "model.add(layers.GlobalMaxPooling2D())\n",
+ "\n",
+ "# Finally, we add a classification layer.\n",
+ "model.add(layers.Dense(10))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a2d3335a90fa"
+ },
+ "source": [
+ "非常实用,对吧?\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "46addede37f3"
+ },
+ "source": [
+ "## 有了模型后该怎么办\n",
+ "\n",
+ "一旦模型架构准备就绪,您将需要执行以下操作:\n",
+ "\n",
+ "- 训练您的模型、评估模型并运行推断。请参阅我们的[使用内置循环的训练和评估指南](https://tensorflow.google.cn/guide/keras/train_and_evaluate/)\n",
+ "- 将模型保存到磁盘并将其还原。请参阅我们的[序列化和保存指南](https://tensorflow.google.cn/guide/keras/save_and_serialize/)。\n",
+ "- 利用多个 GPU 加速模型训练。请参阅我们的[多 GPU 和分布式训练指南](https://keras.io/guides/distributed_training/)。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "608f3b03669c"
+ },
+ "source": [
+ "## 使用序贯模型进行特征提取\n",
+ "\n",
+ "一旦构建了序贯模型,它的行为就类似于[函数式 API 模型](https://tensorflow.google.cn/guide/keras/functional/)。这意味着每层都有一个 `input` 和 `output` 属性。这些属性可用于执行一些巧妙的操作,例如快速创建一个模型来提取序贯模型中所有中间层的输出:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "a5888d753301"
+ },
+ "outputs": [],
+ "source": [
+ "initial_model = keras.Sequential(\n",
+ " [\n",
+ " keras.Input(shape=(250, 250, 3)),\n",
+ " layers.Conv2D(32, 5, strides=2, activation=\"relu\"),\n",
+ " layers.Conv2D(32, 3, activation=\"relu\"),\n",
+ " layers.Conv2D(32, 3, activation=\"relu\"),\n",
+ " ]\n",
+ ")\n",
+ "feature_extractor = keras.Model(\n",
+ " inputs=initial_model.inputs,\n",
+ " outputs=[layer.output for layer in initial_model.layers],\n",
+ ")\n",
+ "\n",
+ "# Call feature extractor on test input.\n",
+ "x = tf.ones((1, 250, 250, 3))\n",
+ "features = feature_extractor(x)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4abef35355d3"
+ },
+ "source": [
+ "下面是一个仅从一层提取特征的类似示例:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fc404c7ac90e"
+ },
+ "outputs": [],
+ "source": [
+ "initial_model = keras.Sequential(\n",
+ " [\n",
+ " keras.Input(shape=(250, 250, 3)),\n",
+ " layers.Conv2D(32, 5, strides=2, activation=\"relu\"),\n",
+ " layers.Conv2D(32, 3, activation=\"relu\", name=\"my_intermediate_layer\"),\n",
+ " layers.Conv2D(32, 3, activation=\"relu\"),\n",
+ " ]\n",
+ ")\n",
+ "feature_extractor = keras.Model(\n",
+ " inputs=initial_model.inputs,\n",
+ " outputs=initial_model.get_layer(name=\"my_intermediate_layer\").output,\n",
+ ")\n",
+ "# Call feature extractor on test input.\n",
+ "x = tf.ones((1, 250, 250, 3))\n",
+ "features = feature_extractor(x)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4e2fb64f0676"
+ },
+ "source": [
+ "## 使用序贯模型进行迁移学习\n",
+ "\n",
+ "迁移学习包括冻结模型中的底层并仅训练顶层。如果您不熟悉迁移学习,请务必阅读我们的[迁移学习指南](https://tensorflow.google.cn/guide/keras/transfer_learning/)。\n",
+ "\n",
+ "下面是涉及序贯模型的两种常见迁移学习蓝图。\n",
+ "\n",
+ "首先,假设您有一个序贯模型,并且想要冻结除最后一层之外的所有层。在这种情况下,只需迭代 `model.layers` 并在除最后一层之外的每一层上设置 `layer.trainable = False`。示例代码如下:\n",
+ "\n",
+ "```python\n",
+ "model = keras.Sequential([\n",
+ " keras.Input(shape=(784)),\n",
+ " layers.Dense(32, activation='relu'),\n",
+ " layers.Dense(32, activation='relu'),\n",
+ " layers.Dense(32, activation='relu'),\n",
+ " layers.Dense(10),\n",
+ "])\n",
+ "\n",
+ "# Presumably you would want to first load pre-trained weights.\n",
+ "model.load_weights(...)\n",
+ "\n",
+ "# Freeze all layers except the last one.\n",
+ "for layer in model.layers[:-1]:\n",
+ " layer.trainable = False\n",
+ "\n",
+ "# Recompile and train (this will only update the weights of the last layer).\n",
+ "model.compile(...)\n",
+ "model.fit(...)\n",
+ "```\n",
+ "\n",
+ "另一个常见的蓝图是使用序贯模型来堆叠预训练模型和一些新初始化的分类层。示例代码如下:\n",
+ "\n",
+ "```python\n",
+ "# Load a convolutional base with pre-trained weights\n",
+ "base_model = keras.applications.Xception(\n",
+ " weights='imagenet',\n",
+ " include_top=False,\n",
+ " pooling='avg')\n",
+ "\n",
+ "# Freeze the base model\n",
+ "base_model.trainable = False\n",
+ "\n",
+ "# Use a Sequential model to add a trainable classifier on top\n",
+ "model = keras.Sequential([\n",
+ " base_model,\n",
+ " layers.Dense(1000),\n",
+ "])\n",
+ "\n",
+ "# Compile & train\n",
+ "model.compile(...)\n",
+ "model.fit(...)\n",
+ "```\n",
+ "\n",
+ "如果您进行迁移学习,您可能会发现自己经常使用这两种模式。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fcffc33b61e5"
+ },
+ "source": [
+ "上面是您需要了解的有关序贯模型的全部信息!\n",
+ "\n",
+ "要详细了解如何在 Keras 中构建模型,请参阅:\n",
+ "\n",
+ "- [函数式 API 指南](https://tensorflow.google.cn/guide/keras/functional/)\n",
+ "- [通过子类化创建新层和模型的指南](https://tensorflow.google.cn/guide/keras/custom_layers_and_models/)"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "sequential_model.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/site/zh-cn/guide/migrate/metrics_optimizers.ipynb b/site/zh-cn/guide/migrate/metrics_optimizers.ipynb
new file mode 100644
index 0000000000..78c03405fa
--- /dev/null
+++ b/site/zh-cn/guide/migrate/metrics_optimizers.ipynb
@@ -0,0 +1,363 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wJcYs_ERTnnI"
+ },
+ "source": [
+ "##### Copyright 2021 The TensorFlow Authors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "HMUDt0CiUJk9"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "77z2OchJTk0l"
+ },
+ "source": [
+ "# 迁移指标和优化器\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "meUTrR4I6m1C"
+ },
+ "source": [
+ "在 TF1 中,`tf.metrics` 是所有指标函数的 API 命名空间。每个指标都是一个将 `label` 和 `prediction` 作为输入参数,并返回相应指标张量作为结果的函数。在 TF2 中,`tf.keras.metrics` 包含所有指标函数和对象。`Metric` 对象可以与 `tf.keras.Model` 和 `tf.keras.layers.layer` 一起使用来计算指标值。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YdZSoIXEbhg-"
+ },
+ "source": [
+ "## 安装\n",
+ "\n",
+ "从几个必要的 TensorFlow 导入开始:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "iE0vSfMXumKI"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "import tensorflow.compat.v1 as tf1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Jsm9Rxx7s1OZ"
+ },
+ "source": [
+ "然后,准备一个用于演示的简单数据:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "m7rnGxsXtDkV"
+ },
+ "outputs": [],
+ "source": [
+ "features = [[1., 1.5], [2., 2.5], [3., 3.5]]\n",
+ "labels = [0, 0, 1]\n",
+ "eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]\n",
+ "eval_labels = [0, 1, 1]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xswk0d4xrFaQ"
+ },
+ "source": [
+ "## TF1:具有 Estimator 的 tf.compat.v1.metrics\n",
+ "\n",
+ "在 TF1 中,指标可以作为 `eval_metric_ops` 添加到 `EstimatorSpec` 中,并且运算通过 `tf.metrics` 中定义的所有指标函数生成。可以按照示例了解如何使用 `tf.metrics.accuracy`。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lqe9obf7suIj"
+ },
+ "outputs": [],
+ "source": [
+ "def _input_fn():\n",
+ " return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)\n",
+ "\n",
+ "def _eval_input_fn():\n",
+ " return tf1.data.Dataset.from_tensor_slices(\n",
+ " (eval_features, eval_labels)).batch(1)\n",
+ "\n",
+ "def _model_fn(features, labels, mode):\n",
+ " logits = tf1.layers.Dense(2)(features)\n",
+ " predictions = tf.math.argmax(input=logits, axis=1)\n",
+ " loss = tf1.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)\n",
+ " optimizer = tf1.train.AdagradOptimizer(0.05)\n",
+ " train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())\n",
+ " accuracy = tf1.metrics.accuracy(labels=labels, predictions=predictions)\n",
+ " return tf1.estimator.EstimatorSpec(mode, \n",
+ " predictions=predictions,\n",
+ " loss=loss, \n",
+ " train_op=train_op,\n",
+ " eval_metric_ops={'accuracy': accuracy})\n",
+ "\n",
+ "estimator = tf1.estimator.Estimator(model_fn=_model_fn)\n",
+ "estimator.train(_input_fn)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "HsOpjW5plH9Q"
+ },
+ "outputs": [],
+ "source": [
+ "estimator.evaluate(_eval_input_fn)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Wk4C6qA_OaQx"
+ },
+ "source": [
+ "此外,可以通过 `tf.estimator.add_metrics()` 直接将指标添加到 Estimator 中。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "B2lpLOh9Owma"
+ },
+ "outputs": [],
+ "source": [
+ "def mean_squared_error(labels, predictions):\n",
+ " labels = tf.cast(labels, predictions.dtype)\n",
+ " return {\"mean_squared_error\": \n",
+ " tf1.metrics.mean_squared_error(labels=labels, predictions=predictions)}\n",
+ "\n",
+ "estimator = tf1.estimator.add_metrics(estimator, mean_squared_error)\n",
+ "estimator.evaluate(_eval_input_fn)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KEmzBjfnsxwT"
+ },
+ "source": [
+ "## TF2:具有 tf.keras.Model 的 Keras Metrics API\n",
+ "\n",
+ "在 TF2 中,`tf.keras.metrics` 包含所有指标类和函数。它们以 OOP 风格设计,并与其他 `tf.keras` API 紧密集成。所有指标都可以在 `tf.keras.metrics` 命名空间中找到,并且 `tf.compat.v1.metrics` 与 `tf.keras.metrics` 之间通常存在直接映射。\n",
+ "\n",
+ "在以下示例中,指标添加到 `model.compile()` 方法中。用户只需要创建指标实例,无需指定标签和预测张量。Keras 模型会将模型输出和标签发送到指标对象。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "atVciNgPs0fw"
+ },
+ "outputs": [],
+ "source": [
+ "dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)\n",
+ "eval_dataset = tf.data.Dataset.from_tensor_slices(\n",
+ " (eval_features, eval_labels)).batch(1)\n",
+ "\n",
+ "inputs = tf.keras.Input((2,))\n",
+ "logits = tf.keras.layers.Dense(2)(inputs)\n",
+ "predictions = tf.math.argmax(input=logits, axis=1)\n",
+ "model = tf.keras.models.Model(inputs, predictions)\n",
+ "optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)\n",
+ "\n",
+ "model.compile(optimizer, loss='mse', metrics=[tf.keras.metrics.Accuracy()])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Kip65sYBlKiu"
+ },
+ "outputs": [],
+ "source": [
+ "model.evaluate(eval_dataset, return_dict=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_mcGoCm_X1V0"
+ },
+ "source": [
+ "启用 Eager Execution 后,`tf.keras.metrics.Metric` 实例可直接用于评估 numpy 数据或 Eager 张量。`tf.keras.metrics.Metric` 对象是有状态容器。指标值可以通过 `metric.update_state(y_true, y_pred)` 进行更新,结果可以通过 `metrics.result()` 进行检索。\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "TVGn5_IhYhtG"
+ },
+ "outputs": [],
+ "source": [
+ "accuracy = tf.keras.metrics.Accuracy()\n",
+ "\n",
+ "accuracy.update_state(y_true=[0, 0, 1, 1], y_pred=[0, 0, 0, 1])\n",
+ "accuracy.result().numpy()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "wQEV2hHtY_su"
+ },
+ "outputs": [],
+ "source": [
+ "accuracy.update_state(y_true=[0, 0, 1, 1], y_pred=[0, 0, 0, 0])\n",
+ "accuracy.update_state(y_true=[0, 0, 1, 1], y_pred=[1, 1, 0, 0])\n",
+ "accuracy.result().numpy()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "E3F3ElcyadW-"
+ },
+ "source": [
+ "有关 `tf.keras.metrics.Metric` 的更多详情,请查看 `tf.keras.metrics.Metric` 下的 API 文档以及[迁移指南](https://tensorflow.google.cn/guide/effective_tf2#new-style_metrics_and_losses)。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eXKY9HEulxQC"
+ },
+ "source": [
+ "## 将 TF1.x 优化器迁移到 Keras 优化器\n",
+ "\n",
+ "`tf.compat.v1.train` 中的优化器(如 [Adam 优化器](https://tensorflow.google.cn/api_docs/python/tf/compat/v1/train/AdamOptimizer)和[梯度下降优化器](https://tensorflow.google.cn/api_docs/python/tf/compat/v1/train/GradientDescentOptimizer))在 `tf.keras.optimizers` 中具有等效项。\n",
+ "\n",
+ "下表总结了如何将这些旧版优化器转换为 Keras 等效项。除非需要额外的步骤(例如[更新默认学习率](../../guide/effective_tf2.ipynb#optimizer_defaults)),否则可以直接将 TF1.x 版本替换为 TF2 版本。\n",
+ "\n",
+ "请注意,转换优化器[可能会使旧的检查点不兼容](./migrating_checkpoints.ipynb)。\n",
+ "\n",
+ "\n",
+ " \n",
+ " TF1.x | \n",
+ " TF2 | \n",
+ " 额外步骤 | \n",
+ "
\n",
+ " \n",
+ " `tf.v1.train.GradientDescentOptimizer` | \n",
+ " `tf.keras.optimizers.SGD` | \n",
+ " 无 | \n",
+ "
\n",
+ " \n",
+ " `tf.v1.train.MomentumOptimizer` | \n",
+ " `tf.keras.optimizers.SGD` | \n",
+ " 包含 `momentum` 参数 | \n",
+ "
\n",
+ " \n",
+ " `tf.v1.train.AdamOptimizer` | \n",
+ " `tf.keras.optimizers.Adam` | \n",
+ " 将 `beta1` 和 `beta2` 参数重命名为 `beta_1` 和 `beta_2` | \n",
+ "
\n",
+ " \n",
+ " `tf.v1.train.RMSPropOptimizer` | \n",
+ " `tf.keras.optimizers.RMSprop` | \n",
+ " 将 `decay` 参数重命名为 `rho` | \n",
+ "
\n",
+ " \n",
+ " `tf.v1.train.AdadeltaOptimizer` | \n",
+ " `tf.keras.optimizers.Adadelta` | \n",
+ " 无 | \n",
+ "
\n",
+ " \n",
+ " `tf.v1.train.AdagradOptimizer` | \n",
+ " `tf.keras.optimizers.Adagrad` | \n",
+ " 无 | \n",
+ "
\n",
+ " \n",
+ " `tf.v1.train.FtrlOptimizer` | \n",
+ " `tf.keras.optimizers.Ftrl` | \n",
+ " 移除 `accum_name` 和 `linear_name` 参数 | \n",
+ "
\n",
+ " \n",
+ " `tf.contrib.AdamaxOptimizer` | \n",
+ " `tf.keras.optimizers.Adamax` | \n",
+ " 将 `beta1` 和 `beta2` 参数重命名为 `beta_1` 和 `beta_2` | \n",
+ "
\n",
+ " \n",
+ " `tf.contrib.Nadam` | \n",
+ " `tf.keras.optimizers.Nadam` | \n",
+ " 将 `beta1` 和 `beta2` 参数重命名为 `beta_1` 和 `beta_2` | \n",
+ "
\n",
+ "
\n",
+ "\n",
+ "注:在 TF2 中,所有 ε(数值稳定性常数)现在默认为 `1e-7`,而不是 `1e-8`。在大多数用例中,这种差异可以忽略不计。"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "metrics_optimizers.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/site/zh-cn/guide/migrate/migrating_feature_columns.ipynb b/site/zh-cn/guide/migrate/migrating_feature_columns.ipynb
index 532e8f40a1..87ce7bf949 100644
--- a/site/zh-cn/guide/migrate/migrating_feature_columns.ipynb
+++ b/site/zh-cn/guide/migrate/migrating_feature_columns.ipynb
@@ -945,7 +945,7 @@
"\n",
"\n",
"\n",
- "* `output_mode` 可以传递给 `tf.keras.layers.CategoryEncoding`、`tf.keras.layers.StringLookup`、`tf.keras.layers.IntegerLookup` 和 `tf.keras.layers.TextVectorization`。\n",
+ "`output_mode` 可以传递给 `tf.keras.layers.CategoryEncoding`、`tf.keras.layers.StringLookup`、`tf.keras.layers.IntegerLookup` 和 `tf.keras.layers.TextVectorization`。\n",
"\n",
"† `tf.keras.layers.TextVectorization` 可以直接处理自由格式的文本输入(例如,整个句子或段落)。这不是 TensorFlow 1 中分类序列处理的一对一替代,但可以为临时文本预处理提供方便的替代。\n",
"\n",
diff --git a/site/zh-cn/guide/mixed_precision.ipynb b/site/zh-cn/guide/mixed_precision.ipynb
index a33359176e..7e95102830 100644
--- a/site/zh-cn/guide/mixed_precision.ipynb
+++ b/site/zh-cn/guide/mixed_precision.ipynb
@@ -47,10 +47,13 @@
},
"source": [
""
]
},
@@ -62,7 +65,7 @@
"source": [
"## 概述\n",
"\n",
- "混合精度是指训练时在模型中同时使用 16 位和 32 位浮点类型,从而加快运行速度,减少内存使用的一种训练方法。通过让模型的某些部分保持使用 32 位类型以保持数值稳定性,可以缩短模型的单步用时,而在评估指标(如准确率)方面仍可以获得同等的训练效果。本指南介绍如何使用 Keras 混合精度 API 来加快模型速度。利用此 API 可以在现代 GPU 上将性能提高三倍以上,而在 TPU 上可以提高 60%。"
+ "混合精度是指训练时在模型中同时使用 16 位和 32 位浮点类型,这样可以加快运行速度,以及减少内存使用。通过让模型的某些部分使用 32 位类型以保持数值稳定性,可以缩短模型的单步用时,而在评估指标(如准确率)方面仍可以获得同等的训练效果。本指南介绍如何使用 Keras 混合精度 API 来加快模型速度。利用此 API 可以在现代 GPU 上将性能提高三倍以上,而在最新的 Intel CPU 上可以提高两倍以上。"
]
},
{
@@ -73,7 +76,7 @@
"source": [
"如今,大多数模型使用 float32 dtype,这种数据类型占用 32 位内存。但是,还有两种精度较低的 dtype,即 float16 和 bfloat16,它们都是占用 16 位内存。现代加速器使用 16 位 dtype 执行运算的速度更快,因为它们有执行 16 位计算的专用硬件,并且从内存中读取 16 位 dtype 的速度也更快。\n",
"\n",
- "NVIDIA GPU 使用 float16 执行运算的速度比使用 float32 快,而 TPU 使用 bfloat16 执行运算的速度也比使用 float32 快。因此,在这些设备上应尽可能使用精度较低的 dtype。但是,出于对数值的要求,为了让模型训练获得相同的质量,一些变量和计算仍需使用 float32。利用 Keras 混合精度 API,float16 或 bfloat16 可以与 float32 混合使用,从而既可以获得 float16/bfloat16 的性能优势,也可以获得 float32 的数值稳定性。\n",
+ "NVIDIA GPU 使用 float16 执行运算的速度比使用 float32 快,而 TPU 使用 bfloat16 执行运算的速度也比使用 float32 快。因此,在这些设备上应尽可能使用精度较低的 dtype。但是,出于对数值的要求,为了让模型训练获得相同的质量,一些变量和计算仍需使用 float32。利用 Keras 混合精度 API,float16 或 bfloat16 可以与 float32 混合使用,从而既可以获得 float16/bfloat16 的性能优势,也可以获得 float32 的数值稳定性优势。\n",
"\n",
"注:在本指南中,术语“数值稳定性”是指使用较低精度的 dtype(而不是较高精度的 dtype)对模型质量的影响。如果使用 float16 或 bfloat16 执行运算,则与使用 float32 执行运算相比,使用这些较低精度的 dtype 会导致模型获得的评估准确率或其他指标相对较低,那么我们就说这种运算“数值不稳定”。"
]
@@ -84,7 +87,7 @@
"id": "MUXex9ctTuDB"
},
"source": [
- "## 设置"
+ "## 安装"
]
},
{
@@ -110,9 +113,11 @@
"source": [
"## 支持的硬件\n",
"\n",
- "虽然混合精度在大多数硬件上都可以运行,但是在最新的 NVIDIA GPU 和 Cloud TPU 上才能加速模型。NVIDIA GPU 支持混合使用 float16 和 float32,而 TPU 则支持混合使用 bfloat16 和 float32。\n",
+ "虽然混合精度在大多数硬件上都可以运行,但是只有在最新的 NVIDIA GPU、Cloud TPU 和最新的 Intel CPU 上才能加速模型。NVIDIA GPU 支持混合使用 float16 和 float32,而 TPU 则支持混合使用 bfloat16 和 float32。\n",
"\n",
- "在 NVIDIA GPU 中,计算能力为 7.0 或更高的 GPU 可以获得混合精度的最大性能优势,因为这些型号具有称为 Tensor 核心的特殊硬件单元,可以加速 float16 矩阵乘法和卷积运算。旧款 GPU 使用混合精度无法实现数学运算性能优势,不过可以节省内存和带宽,因此也可以在一定程度上提高速度。您可以在 NVIDIA 的 [CUDA GPU 网页](https://developer.nvidia.com/cuda-gpus)上查询 GPU 的计算能力。可以最大程度从混合精度受益的 GPU 示例包括 RTX GPU、V100 和 A100。"
+ "在 NVIDIA GPU 中,计算能力为 7.0 或更高的 GPU 可以获得混合精度的最大性能优势,因为这些型号具有称为 Tensor 核心的特殊硬件单元,可以加速 float16 矩阵乘法和卷积运算。旧款 GPU 使用混合精度无法实现数学运算性能优势,不过可以节省内存和带宽,因此也可以在一定程度上提高速度。您可以在 NVIDIA 的 [CUDA GPU 网页](https://developer.nvidia.com/cuda-gpus)上查询 GPU 的计算能力。可以最大程度从混合精度受益的 GPU 示例包括 RTX GPU、V100 和 A100。\n",
+ "\n",
+ "在 Intel CPU 中,从第四代 Intel 至强处理器(代号 Sapphire Rapids)开始,混合精度将提供最大的性能优势,因为它们可以使用 AMX 指令加速 bfloat16 计算(要求 Tensorflow 2.12 或更高版本)。"
]
},
{
@@ -121,7 +126,7 @@
"id": "-q2hisD60F0_"
},
"source": [
- "注:如果在 Google Colab 中运行本指南中示例,则 GPU 运行时通常会连接 P100。P100 的计算能力为 6.0,预计速度提升不明显。\n",
+ "注:如果在 Google Colab 中运行本指南中示例,GPU 运行时通常会连接 P100。P100 的计算能力为 6.0,预计速度提升不明显。如果在 CPU 运行时上运行,速度可能会变慢,因为运行时可能有一个不支持 AMX 的 CPU。\n",
"\n",
"您可以使用以下命令检查 GPU 类型。如果要使用此命令,必须安装 NVIDIA 驱动程序,否则会引发错误。"
]
@@ -145,7 +150,7 @@
"source": [
"所有 Cloud TPU 均支持 bfloat16。\n",
"\n",
- "即使在预计无法提升速度的 CPU 和旧款 GPU 上,混合精度 API 仍可以用于单元测试、调试或试用 API。不过,在 CPU 上,混合精度的运行速度会明显变慢。"
+ "即使在预计无法提升速度的旧款 Intel CPU、不支持 AMX 的其他 x86 CPU 和旧款 GPU 上,混合精度 API 仍可以用于单元测试、调试或试用 API。但是,不支持 AMX 指令的 CPU 上的 mix_bfloat16 以及所有 x86 CPU 上的 mix_float16 的运行速度会明显变慢。"
]
},
{
@@ -163,7 +168,7 @@
"id": "54ecYY2Hn16E"
},
"source": [
- "要在 Keras 中使用混合精度,您需要创建一条 `tf.keras.mixed_precision.Policy`,通常将其称为 *dtype 策略*。Dtype 策略可以指定将在其中运行的 dtype 层。在本指南中,您将从字符串 `'mixed_float16'` 构造策略,并将其设置为全局策略。这会导致随后创建的层使用 float16 和 float32 的混合精度。"
+ "要在 Keras 中使用混合精度,您需要创建一条 `tf.keras.mixed_precision.Policy`,通常将其称为 *dtype 策略*。Dtype 策略可以指定将在其中运行的 dtype 层。在本指南中,您将从字符串 `'mixed_float16'` 构造策略,并将其设置为全局策略。这会导致随后创建的层使用 float16 和 float32 的混合精度。"
]
},
{
@@ -226,7 +231,7 @@
"id": "MOFEcna28o4T"
},
"source": [
- "如前所述,在计算能力至少为 7.0 的 NVIDIA GPU 上,`mixed_float16` 策略可以大幅提升性能。在其他 GPU 和 CPU 上,该策略也可以运行,但可能无法提升性能。对于 TPU,则应使用 `mixed_bfloat16` 策略。"
+ "如前所述,在计算能力至少为 7.0 的 NVIDIA GPU 上,`mixed_float16` 策略可以大幅提升性能。在其他 GPU 和 CPU 上,该策略也可以运行,但可能无法提升性能。对于 TPU 和 CPU,则应使用 `mixed_bfloat16` 策略。"
]
},
{
@@ -453,7 +458,7 @@
"source": [
"请注意,模型会在日志中打印每个步骤的时间:例如,“25ms/step”。第一个周期可能会变慢,因为 TensorFlow 会花一些时间来优化模型,但之后每个步骤的时间应当会稳定下来。\n",
"\n",
- "如果在 Colab 中运行本指南中,您可以使用 float32 比较混合精度的性能。为此,请在“Setting the dtype policy”部分将策略从 `mixed_float16` 更改为 `float32`,然后重新运行所有代码单元,直到此代码点。在计算能力至少为 7.0 的 GPU 上,您会发现每个步骤的时间大大增加,表明混合精度提升了模型的速度。在继续学习本指南之前,请确保将策略改回 `mixed_float16` 并重新运行代码单元。\n",
+ "如果在 Colab 中运行本指南中,您可以使用 float32 比较混合精度的性能。为此,请在“Setting the dtype policy”部分将策略从 `mixed_float16` 更改为 `float32`,然后重新运行所有代码单元,直到此代码点。在计算能力至少为 7.X 的 GPU 上,您会发现每个步骤的时间大大增加,表明混合精度提升了模型的速度。在继续学习本指南之前,请确保将策略改回 `mixed_float16` 并重新运行代码单元。\n",
"\n",
"在计算能力至少为 8.0 的 GPU(Ampere GPU 及更高版本)上,使用混合精度时,与使用 float32 相比,您可能看不到本指南中小模型的性能提升。这是由于使用 [TensorFloat-32](https://tensorflow.google.cn/api_docs/python/tf/config/experimental/enable_tensor_float_32_execution) 导致的,它会在 `tf.linalg.matmul` 等某些 float32 运算中自动使用较低精度的数学计算。使用 float32 时,TensorFloat-32 会展现混合精度的一些性能优势。不过,在真实模型中,由于内存带宽节省和 TensorFloat-32 不支持的运算,您通常仍会看到混合精度的显著性能提升。\n",
"\n",
@@ -470,7 +475,9 @@
"source": [
"## 损失放大\n",
"\n",
- "损失放大是 `tf.keras.Model.fit` 使用 `mixed_float16` 策略自动执行,从而避免数值下溢的一种技术。本部分介绍什么是损失放大,下一部分介绍如何将其与自定义训练循环一起使用。"
+ "损失放大是 `tf.keras.Model.fit` 使用 `mixed_float16` 策略自动执行,从而避免数值下溢的一种技术。本部分介绍什么是损失放大,下一部分介绍如何将其与自定义训练循环一起使用。\n",
+ "\n",
+ "注:使用 `mixed_bfloat16` 策略时,不需要进行损失缩放。"
]
},
{
@@ -516,7 +523,7 @@
"id": "pUIbhQypRVe_"
},
"source": [
- "实际上,float16 也极少出现下溢的情况。此外,在正向传递中出现下溢的情形更是十分罕见。但是,在反向传递中,梯度可能因下溢而变为零。损失放大就是一个防止出现下溢的技巧。"
+ "实际上,float16 也极少出现下溢的情况。此外,在前向传递中出现下溢的情形更是十分罕见。但是,在后向传递中,梯度可能因下溢而变为零。损失放大就是一个防止出现下溢的技巧。"
]
},
{
@@ -527,7 +534,7 @@
"source": [
"### 损失放大概述\n",
"\n",
- "损失放大的基本概念非常简单:只需将损失乘以某个大数字(如 $1024$)即可得到*损失放大{/em0值。这会将梯度放大 $1024$ 倍,大大降低了发生下溢的几率。计算出最终梯度后,将其除以 $1024$ 即可得到正确值。*\n",
+ "损失放大的基本概念非常简单:只需将损失乘以某个大数字(如 $1024$)即可得到*损失放大{/em0}值。这会将梯度放大 $1024$ 倍,大大降低了发生下溢的几率。计算出最终梯度后,将其除以 $1024$ 即可得到正确值。*\n",
"\n",
"该过程的伪代码是:\n",
"\n",
@@ -639,7 +646,7 @@
"- `get_scaled_loss(loss)`:将损失值乘以损失标度值\n",
"- `get_unscaled_gradients(gradients)`:获取一系列放大的梯度作为输入,并将每一个梯度除以损失标度,从而将其缩小为实际值\n",
"\n",
- "为了防止梯度发生下溢,必须使用这些函数。随后,如果全部没有出现 Inf 或 NaN 值,则 `LossScaleOptimizer.apply_gradients` 会应用这些梯度。它还会更新损失标度,如果梯度出现 Inf 或 NaN 值,则会将其减半,而如果出现零值,则会增大损失标度。"
+ "为了防止梯度发生下溢,必须使用这些函数。随后,如果全部没有出现 `Inf` 或 `NaN` 值,则 `LossScaleOptimizer.apply_gradients` 会应用这些梯度。它还会更新损失标度,如果梯度出现 `Inf` 或 `NaN` 值,则会将其减半,而如果出现零值,则会增大损失标度。"
]
},
{
@@ -796,17 +803,18 @@
"source": [
"## 总结\n",
"\n",
- "- 如果您使用的是计算能力至少为 7.0 的 TPU 或 NVIDIA GPU,则应使用混合精度,因为它可以将性能提升多达 3 倍。\n",
+ "- 如果您使用的是计算能力至少为 7.0 的 TPU 和 NVIDIA GPU 或支持 AMX 指令的 Intel CPU,则应使用混合精度,因为它可以将性能提升多达 3 倍。\n",
"\n",
"- 您可以按如下代码使用混合精度:\n",
"\n",
" ```python\n",
- " # On TPUs, use 'mixed_bfloat16' instead\n",
+ " # On TPUs and CPUs, use 'mixed_bfloat16' instead\n",
" mixed_precision.set_global_policy('mixed_float16')\n",
" ```\n",
"\n",
"- 如果您的模型以 softmax 结尾,请确保其类型为 float32。不管您的模型以什么结尾,必须确保输出为 float32。\n",
"- 如果您通过 `mixed_float16` 使用自定义训练循环,则除了上述几行代码外,您还需要使用 `tf.keras.mixed_precision.LossScaleOptimizer` 封装您的优化器。然后调用 `optimizer.get_scaled_loss` 来放大损失,并且调用 `optimizer.get_unscaled_gradients` 来缩小梯度。\n",
+ "- 如果您正在通过 `mixed_bfloat16` 使用自定义训练循环,则设置上面提到的 global_policy 已足够。\n",
"- 如果不会降低计算准确率,则可以将训练批次大小加倍。\n",
"- 在 GPU 上,确保大部分张量维度是 $8$ 的倍数,从而最大限度提高性能\n",
"\n",
diff --git a/site/zh-cn/guide/profiler.md b/site/zh-cn/guide/profiler.md
index 380201d106..5430664804 100644
--- a/site/zh-cn/guide/profiler.md
+++ b/site/zh-cn/guide/profiler.md
@@ -126,7 +126,7 @@ Profiler 提供了多种工具来帮助您进行性能分析:
要打开输入流水线分析器,请选择 **Profile**,然后在 **Tools** 下拉列表中选择 **input_pipeline_analyzer**。
- ![image](./images/tf_profiler/overview_page.png?raw=true)
+![image](./images/tf_profiler/overview_page.png?raw=true)
信息中心包含三个版块:
@@ -221,7 +221,7 @@ Trace Viewer 会显示一个包含以下信息的时间线:
当您打开 Trace Viewer 时,它会显示您最近的运行:
- ![image](./images/tf_profiler/gpu_kernel_stats.png?raw=true)
+![image](./images/tf_profiler/gpu_kernel_stats.png?raw=true)
此画面包含以下主要元素:
@@ -252,7 +252,7 @@ Trace Viewer 包含以下版块:
Trace Viewer 还可以显示您的 TensorFlow 程序中 Python 函数调用的跟踪记录。如果您使用 `tf.profiler.experimental.start()` API,可以在开始性能剖析时使用 `ProfilerOptions` 命名元组启用 Python 跟踪。或者,如果您使用采样模式进行性能剖析,可以使用 **Capture Profile** 对话框中的下拉选项选择跟踪级别。
- ![image](./images/tf_profiler/overview_page.png?raw=true)
+![image](./images/tf_profiler/overview_page.png?raw=true)
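+
+一个最小草图,展示如何在编程模式下开始性能剖析时启用 Python 跟踪(`logdir` 为假设的日志目录,参数取值仅作示意):
+
+```python
+options = tf.profiler.experimental.ProfilerOptions(python_tracer_level=1)
+tf.profiler.experimental.start('logdir', options=options)
+# ...运行需要剖析的训练步骤...
+tf.profiler.experimental.stop()
+```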
@@ -322,11 +322,11 @@ Trace Viewer 还可以显示您的 TensorFlow 程序中 Python 函数调用的
此版块显示了内存使用量(以 GiB 为单位)以及碎片百分比与时间(以毫秒为单位)关系图。
- ![image](./images/tf_profiler/memory_timeline_graph.png?raw=true)
+![image](./images/tf_profiler/memory_timeline_graph.png?raw=true)
X 轴表示性能剖析间隔的时间线(以毫秒为单位)。左侧的 Y 轴表示内存使用量(以 GiB 为单位),右侧的 Y 轴表示碎片百分比。在 X 轴上的每个时间点,总内存都分为三类:堆栈(红色)、堆(橙色)和可用(绿色)。将鼠标悬停在特定的时间戳上可以查看有关此时内存分配/释放事件的详细信息,具体如下所示:
- ![image](./images/tf_profiler/memory_timeline_graph_popup.png?raw=true)
+![image](./images/tf_profiler/memory_timeline_graph_popup.png?raw=true)
弹出窗口显示以下信息:
@@ -365,7 +365,7 @@ X 轴表示性能剖析间隔的时间线(以毫秒为单位)。左侧的 Y
Pod Viewer 工具可以显示一个训练步骤在所有工作进程中的细分。
- ![image](./images/tf_profiler/pod_viewer.png?raw=true)
+![image](./images/tf_profiler/pod_viewer.png?raw=true)
- 上部窗格具有用于选择步骤编号的滑块。
- 下部窗格显示堆叠的柱状图。这是细分的步骤-时间类别彼此叠加的高级视图。每个堆叠的柱状图代表一个唯一的工作进程。
@@ -391,7 +391,7 @@ Pod Viewer 工具可以显示一个训练步骤在所有工作进程中的细分
#### Performance Analysis Summary
- ![image](./images/tf_profiler/memory_breakdown_table.png?raw=true)
+![image](./images/tf_profiler/memory_breakdown_table.png?raw=true)
此版块提供了分析的摘要。它会报告在性能剖析中是否检测到较慢的 `tf.data` 输入流水线。此版块还会显示最受输入约束的主机及其最慢且具有最大延迟的输入流水线。最重要的是,它会识别输入流水线的哪一部分是瓶颈以及如何解决该瓶颈。瓶颈信息通过迭代器类型及其长名称提供。
@@ -416,7 +416,7 @@ dataset = tf.data.Dataset.range(10).map(lambda x: x).repeat(2).batch(5)
#### 所有输入流水线的摘要
- ![image](./images/tf_profiler/tf_data_all_hosts.png?raw=true)
+![image](./images/tf_profiler/tf_data_all_hosts.png?raw=true)
本部分提供了所有主机上的所有输入流水线的摘要。通常只有一个输入流水线。使用分配策略时,有一个主机输入流水线运行程序的 `tf.data` 代码,多个设备输入流水线从主机输入流水线中检索数据并将其传送到设备。
diff --git a/site/zh-cn/guide/saved_model.ipynb b/site/zh-cn/guide/saved_model.ipynb
index d0ac69fb70..6b98e7994e 100644
--- a/site/zh-cn/guide/saved_model.ipynb
+++ b/site/zh-cn/guide/saved_model.ipynb
@@ -47,10 +47,11 @@
},
"source": [
""
]
},
@@ -79,8 +80,24 @@
"id": "9SuIC7FiI9g8"
},
"source": [
- "## 从 Keras 创建 SavedModel\n",
- "\n",
+ "## 从 Keras 创建 SavedModel"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AtSmftAvhJvE"
+ },
+ "source": [
+ "已弃用:对于 Keras 对象,建议使用新的高级 `.keras` 格式和 `tf.keras.Model.export`,如[此处](https://tensorflow.google.cn/guide/keras/save_and_serialize)的指南所示。对于现有代码,继续支持低级 SavedModel 格式。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eLSOptpYhJvE"
+ },
+ "source": [
"为便于简单介绍,本部分将导出一个预训练 Keras 模型来处理图像分类请求。本指南的其他部分将详细介绍和讨论创建 SavedModel 的其他方式。"
]
},
@@ -179,7 +196,7 @@
"id": "r4KIsQDZJ5PS"
},
"source": [
- "对该图像的顶部预测是“军服”。"
+ "对此图像的热门预测是“军服”。"
]
},
{
@@ -200,7 +217,7 @@
"id": "pyX-ETE3wX63"
},
"source": [
- "保存路径遵循 TensorFlow Serving 使用的惯例,路径的最后一个部分(此处为 `1/`)是模型的版本号——它可以让 Tensorflow Serving 之类的工具推断相对新鲜度。\n",
+ "保存路径遵循 TensorFlow Serving 使用的惯例,路径的最后一个部分(此处为 `1/`)是模型的版本号:它可以让 Tensorflow Serving 之类的工具推断相对新鲜度。\n",
"\n",
"您可以使用 `tf.saved_model.load` 将 SavedModel 加载回 Python,并查看 Admiral Hopper 的图像是如何分类的。"
]
@@ -270,7 +287,7 @@
"source": [
"## 在 TensorFlow Serving 中运行 SavedModel\n",
"\n",
- "可以通过 Python 使用 SavedModel(下文中有详细介绍),但是,生产环境通常会使用专门服务进行推理,而不会运行 Python 代码。使用 TensorFlow Serving 时,这很容易从 SavedModel 进行设置。\n",
+ "可以通过 Python 使用 SavedModel(下文中有详细介绍),但是,生产环境通常会使用专门服务进行推断,而不会运行 Python 代码。使用 TensorFlow Serving 时,这很容易从 SavedModel 进行设置。\n",
"\n",
"请参阅 [TensorFlow Serving REST 教程](https://tensorflow.google.cn/tfx/tutorials/serving/rest_simple)了解端到端 tensorflow-serving 示例。"
]
@@ -283,7 +300,7 @@
"source": [
"## 磁盘上的 SavedModel 格式\n",
"\n",
- "SavedModel 是一个包含序列化签名和运行这些签名所需的状态的目录,其中包括变量值和词汇表。\n"
+ "SavedModel 是一个包含序列化签名和运行这些签名所需的状态的目录,其中包括变量值和词汇。\n"
]
},
{
@@ -303,9 +320,9 @@
"id": "ple4X5utX8ue"
},
"source": [
- "`saved_model.pb` 文件用于存储实际 TensorFlow 程序或模型,以及一组已命名的签名——每个签名标识一个接受张量输入和产生张量输出的函数。\n",
+ "`saved_model.pb` 文件用于存储实际 TensorFlow 程序或模型,以及一组已命名的签名,每个签名标识一个接受张量输入和产生张量输出的函数。\n",
"\n",
- "SavedModel 可能包含模型的多个变体(多个 `v1.MetaGraphDefs`,通过 `saved_model_cli` 的 `--tag_set` 标记进行标识),但这种情况很少见。可以为模型创建多个变体的 API 包括 [tf.Estimator.experimental_export_all_saved_models](https://tensorflow.google.cn/api_docs/python/tf/estimator/Estimator#experimental_export_all_saved_models) 和 TensorFlow 1.x 中的 `tf.saved_model.Builder`。"
+ "SavedModel 可能包含模型的多个变体(多个 `v1.MetaGraphDefs`,通过 `saved_model_cli` 的 `--tag_set` 标志进行标识),但这种情况很少见。可以为模型创建多个变体的 API 包括 [`tf.Estimator.experimental_export_all_saved_models`](https://tensorflow.google.cn/api_docs/python/tf/estimator/Estimator#experimental_export_all_saved_models) 和 TensorFlow 1.x 中的 `tf.saved_model.Builder`。"
]
},
{
@@ -347,7 +364,9 @@
"source": [
"`assets` 目录包含 TensorFlow 计算图使用的文件,例如,用于初始化词汇表的文本文件。本例中没有使用这种文件。\n",
"\n",
- "SavedModel 可能有一个用于保存 TensorFlow 计算图未使用的任何文件的 `assets.extra` 目录,例如,为使用者提供的关于如何处理 SavedModel 的信息。TensorFlow 本身并不会使用此目录。"
+ "SavedModel 可能有一个用于保存 TensorFlow 计算图未使用的任何文件的 `assets.extra` 目录,例如,为使用者提供的关于如何处理 SavedModel 的信息。TensorFlow 本身并不会使用此目录。\n",
+ "\n",
+ "`fingerprint.pb` 文件包含 SavedModel 的[指纹](https://en.wikipedia.org/wiki/Fingerprint_(computing)),它由几个 64 位哈希组成,以唯一的方式标识 SavedModel 的内容。指纹 API 目前处于实验阶段,但 `tf.saved_model.experimental.read_fingerprint` 可以用于将 SavedModel 指纹读取到 `tf.saved_model.experimental.Fingerprint` 对象中。"
]
},
{
@@ -395,11 +414,11 @@
"id": "J4FcP-Co3Fnw"
},
"source": [
- "当您保存 `tf.Module` 时,任何 `tf.Variable` 特性、`tf.function` 装饰的方法以及通过递归遍历找到的 `tf.Module` 都会得到保存。(参阅[检查点教程](./checkpoint.ipynb),了解此递归便利的详细信息。)但是,所有 Python 特性、函数和数据都会丢失。也就是说,当您保存 `tf.function` 时,不会保存 Python 代码。\n",
+ "当您保存 `tf.Module` 时,任何 `tf.Variable` 特性、`tf.function` 装饰的方法以及通过递归遍历找到的 `tf.Module` 都会得到保存。(参阅[检查点教程](./checkpoint.ipynb),了解此递归遍历的详细信息。)但是,所有 Python 特性、函数和数据都会丢失。也就是说,当您保存 `tf.function` 时,不会保存 Python 代码。\n",
"\n",
"如果不保存 Python 代码,SavedModel 如何知道怎样恢复函数?\n",
"\n",
- "简单地说,`tf.function` 的工作原理是,通过跟踪 Python 代码来生成 ConcreteFunction(一个可调用的 `tf.Graph` 包装器)。当您保存 `tf.function` 时,实际上保存的是 `tf.function` 的 ConcreteFunction 缓存。\n",
+ "简单地说,`tf.function` 的工作原理是,通过跟踪 Python 代码来生成 ConcreteFunction(一个可调用的 `tf.Graph` 封装容器)。当您保存 `tf.function` 时,实际上保存的是 `tf.function` 的 ConcreteFunction 缓存。\n",
"\n",
"要详细了解 `tf.function` 与 ConcreteFunction 之间的关系,请参阅 [tf.function 指南](function.ipynb)。"
]
@@ -433,7 +452,7 @@
"id": "QpxQy5Eb77qJ"
},
"source": [
- "在 Python 中加载 SavedModel 时,所有 `tf.Variable` 特性、`tf.function` 装饰方法和 `tf.Module` 都会按照与原始保存的 `tf.Module` 相同对象结构进行恢复。"
+ "在 Python 中加载 SavedModel 时,所有 `tf.Variable` 特性、`tf.function` 装饰方法和 `tf.Module` 都会按照与原始保存的 `tf.Module` 相同的对象结构进行恢复。"
]
},
{
@@ -519,7 +538,7 @@
"\n",
"与普通 `__call__` 相比,Keras 的 SavedModel 提供了[更多详细信息](https://github.com/tensorflow/community/blob/master/rfcs/20190509-keras-saved-model.md#serialization-details)来解决更复杂的微调情形。TensorFlow Hub 建议在共享的 SavedModel 中提供以下详细信息(如果适用),以便进行微调:\n",
"\n",
- "- 如果模型使用随机失活,或者是训练与推理之间的前向传递不同的另一种技术(如批次归一化),则 `__call__` 方法会获取一个可选的 Python 值 `training=` 参数。该参数的默认值为 `False`,但可将其设置为 `True`。\n",
+ "- 如果模型使用随机失活,或者是训练与推断之间的前向传递不同的另一种技术(如批次归一化),则 `__call__` 方法会获取一个可选的 Python 值 `training=` 参数。该参数的默认值为 `False`,但可将其设置为 `True`。\n",
"- 对于变量的对应列表,除了 `__call__` 特性,还有 `.variable` 和 `.trainable_variable` 特性。在微调过程中,`.trainable_variables` 省略了一个变量,该变量原本可训练,但打算将其冻结。\n",
"- 对于 Keras 等将权重正则化项表示为层或子模型特性的框架,还有一个 `.regularization_losses` 特性。它包含一个零参数函数的列表,这些函数的值应加到总损失中。\n",
"\n",
@@ -588,7 +607,7 @@
"id": "BiNtaMZSI8Tb"
},
"source": [
- "要声明服务上线签名,请使用 `signatures` 关键字参数指定 ConcreteFunction。当指定单个签名时,签名键为 `'serving_default'`,并将保存为常量 `tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY`。"
+ "要声明应用签名,请使用 `signatures` 关键字参数指定 ConcreteFunction。指定单个签名时,签名键为 `'serving_default'`,并将保存为常量 `tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY`。"
]
},
{
@@ -674,7 +693,7 @@
" super(CustomModuleWithOutputName, self).__init__()\n",
" self.v = tf.Variable(1.)\n",
"\n",
- " @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])\n",
+ " @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])\n",
" def __call__(self, x):\n",
" return {'custom_output_name': x * self.v}\n",
"\n",
@@ -697,6 +716,38 @@
"imported_with_output_name.signatures['serving_default'].structured_outputs"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Q4bCK55x1IBW"
+ },
+ "source": [
+ "## proto 分割\n",
+ "\n",
+ "注:此功能将成为 TensorFlow 2.15 版本的一部分。它目前在 Nightly 版本中提供,您可以使用 `pip install tf-nightly` 进行安装。\n",
+ "\n",
+ "由于 protobuf 实现的限制,proto 的大小不能超过 2GB。在尝试保存非常大的模型时,这可能会导致以下错误:\n",
+ "\n",
+ "```\n",
+ "ValueError: Message tensorflow.SavedModel exceeds maximum protobuf size of 2GB: ...\n",
+ "```\n",
+ "\n",
+ "```\n",
+ "google.protobuf.message.DecodeError: Error parsing message as the message exceeded the protobuf limit with type 'tensorflow.GraphDef'\n",
+ "```\n",
+ "\n",
+ "如果您希望保存超过 2GB 限制的模型,则需要使用新的 proto 分割选项进行保存:\n",
+ "\n",
+ "```python\n",
+ "tf.saved_model.save(\n",
+ " ...,\n",
+ " options=tf.saved_model.SaveOptions(experimental_image_format=True)\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "更多信息,请参阅 [Proto 分割器/合并器库指南](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/proto_splitter/in-depth-guide.md)。"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {
@@ -724,9 +775,9 @@
"source": [
"\n",
"\n",
- "## SavedModel 命令行界面详解\n",
+ "## SavedModel 命令行接口详细信息\n",
"\n",
- "使用 SavedModel 命令行界面 (CLI) 可以检查和执行 SavedModel。例如,您可以使用 CLI 来检查模型的 `SignatureDef`。通过 CLI,您可以快速确认与模型相符的输入张量的 dtype 和形状。此外,如果要测试模型,您可以通过 CLI 传入各种格式的样本输入(例如,Python 表达式),然后获取输出,从而执行健全性检查。\n",
+ "您可以使用 SavedModel 命令行接口 (CLI) 检查和执行 SavedModel。例如,您可以使用 CLI 来检查模型的 `SignatureDef`。通过 CLI,您可以快速确认与模型相符的输入张量的 dtype 和形状。此外,如果要测试模型,您可以传入各种格式的样本输入(例如,Python 表达式),然后获取输出,使用 CLI 执行健全性检查。\n",
"\n",
"### 安装 SavedModel CLI\n",
"\n",
@@ -752,7 +803,7 @@
"\n",
"### `show` 命令\n",
"\n",
- "SavedModel 包含一个或多个模型变体(技术为 `v1.MetaGraphDef`),这些变体通过 tag-set 进行标识。要为模型提供服务,您可能想知道每个模型变体中使用的具体是哪一种 `SignatureDef` ,以及它们的输入和输出是什么。那么,利用 `show` 命令,您就可以按照层级顺序检查 SavedModel 的内容。具体语法如下:\n",
+ "SavedModel 包含一个或多个模型变体(从技术上说,为 `v1.MetaGraphDef`),这些变体通过 tag-set 进行标识。要应用模型,您可能想知道每个模型变体中使用的具体是哪一种 `SignatureDef` ,以及它们的输入和输出是什么。那么,利用 `show` 命令,您就可以按照层级顺序检查 SavedModel 的内容。具体语法如下:\n",
"\n",
"```\n",
"usage: saved_model_cli show [-h] --dir DIR [--all]\n",
@@ -771,15 +822,7 @@
"以下命令会显示 tag-set 的所有可用 `SignatureDef` 键:\n",
"\n",
"```\n",
- "$ saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve\n",
- "The given SavedModel `MetaGraphDef` contains `SignatureDefs` with the\n",
- "following keys:\n",
- "SignatureDef key: \"classify_x2_to_y3\"\n",
- "SignatureDef key: \"classify_x_to_y\"\n",
- "SignatureDef key: \"regress_x2_to_y3\"\n",
- "SignatureDef key: \"regress_x_to_y\"\n",
- "SignatureDef key: \"regress_x_to_y2\"\n",
- "SignatureDef key: \"serving_default\"\n",
+ "$ saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve The given SavedModel `MetaGraphDef` contains `SignatureDefs` with the following keys: SignatureDef key: \"classify_x2_to_y3\" SignatureDef key: \"classify_x_to_y\" SignatureDef key: \"regress_x2_to_y3\" SignatureDef key: \"regress_x_to_y\" SignatureDef key: \"regress_x_to_y2\" SignatureDef key: \"serving_default\"\n",
"```\n",
"\n",
"如果 tag-set 中有*多个*标记,则必须指定所有标记(标记之间用逗号分隔)。例如:\n",
@@ -789,19 +832,7 @@
"要显示特定 `SignatureDef` 的所有输入和输出 TensorInfo,请将 `SignatureDef` 键传递给 `signature_def` 选项。如果您想知道输入张量的张量键值、dtype 和形状,以便随后执行计算图,这会非常有用。例如:\n",
"\n",
"```\n",
- "$ saved_model_cli show --dir \\\n",
- "/tmp/saved_model_dir --tag_set serve --signature_def serving_default\n",
- "The given SavedModel SignatureDef contains the following input(s):\n",
- " inputs['x'] tensor_info:\n",
- " dtype: DT_FLOAT\n",
- " shape: (-1, 1)\n",
- " name: x:0\n",
- "The given SavedModel SignatureDef contains the following output(s):\n",
- " outputs['y'] tensor_info:\n",
- " dtype: DT_FLOAT\n",
- " shape: (-1, 1)\n",
- " name: y:0\n",
- "Method name is: tensorflow/serving/predict\n",
+ "$ saved_model_cli show --dir \\ /tmp/saved_model_dir --tag_set serve --signature_def serving_default The given SavedModel SignatureDef contains the following input(s): inputs['x'] tensor_info: dtype: DT_FLOAT shape: (-1, 1) name: x:0 The given SavedModel SignatureDef contains the following output(s): outputs['y'] tensor_info: dtype: DT_FLOAT shape: (-1, 1) name: y:0 Method name is: tensorflow/serving/predict\n",
"```\n",
"\n",
"要显示 SavedModel 中的所有可用信息,请使用 `--all` 选项。例如:\n",
@@ -887,7 +918,6 @@
],
"metadata": {
"colab": {
- "collapsed_sections": [],
"name": "saved_model.ipynb",
"toc_visible": true
},
diff --git a/site/zh-cn/guide/tensor.ipynb b/site/zh-cn/guide/tensor.ipynb
index 1d4be66882..208d582a25 100644
--- a/site/zh-cn/guide/tensor.ipynb
+++ b/site/zh-cn/guide/tensor.ipynb
@@ -47,13 +47,10 @@
},
"source": [
""
]
},
@@ -169,19 +166,15 @@
"source": [
"\n",
"\n",
- " 一个标量,形状:[] \n",
- " | \n",
+ " 一个标量,形状:[] | \n",
" 向量,形状:[3] | \n",
" 矩阵,形状:[3, 2] | \n",
"
\n",
"\n",
- " \n",
- " | \n",
+ " | \n",
"\n",
- " \n",
- " | \n",
- " \n",
- " | \n",
+ " | \n",
+ " | \n",
"
\n",
"
\n"
]
@@ -238,13 +231,10 @@
"\n",
"
\n",
"\n",
- " \n",
- " | \n",
- " \n",
- " | \n",
+ " | \n",
+ " | \n",
"\n",
- " \n",
- " | \n",
+ " | \n",
"
\n",
"\n",
""
@@ -470,10 +460,8 @@
" 4 秩张量,形状:[3, 2, 4, 5] | \n",
"\n",
"\n",
- " \n",
- " | \n",
- " \n",
- " | \n",
+ " | \n",
+ " | \n",
"
\n",
"\n"
]
@@ -538,8 +526,7 @@
"典型的轴顺序 | \n",
"\n",
"\n",
- " \n",
- " | \n",
+ " | \n",
"
\n",
""
]
@@ -737,8 +724,7 @@
"\n",
"\n",
" | \n",
- " \n",
- " | \n",
+ " | \n",
"
\n",
""
]
@@ -760,7 +746,7 @@
"source": [
"## 操作形状\n",
"\n",
- "张量重构很有用。\n"
+ "改变张量的形状很有用。\n"
]
},
{
@@ -895,12 +881,9 @@
"\n",
"一些正确的重构示例。 | \n",
"\n",
- " \n",
- " | \n",
- " \n",
- " | \n",
- " \n",
- " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
"
\n",
"
\n"
]
@@ -948,12 +931,9 @@
"\n",
"一些错误的重构示例。 | \n",
"\n",
- " \n",
- " | \n",
- " \n",
- " | \n",
- " \n",
- " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
"
\n",
"
"
]
@@ -1073,8 +1053,7 @@
" 广播相加:[3, 1] 乘以 [1, 4] 的结果是 [3,4] | \n",
"\n",
"\n",
- " \n",
- " | \n",
+ " | \n",
"
\n",
"\n"
]
@@ -1179,8 +1158,7 @@
" “tf.RaggedTensor”,形状:[4, None] | \n",
"\n",
"\n",
- " \n",
- " | \n",
+ " | \n",
"
\n",
""
]
@@ -1310,8 +1288,7 @@
" 字符串向量,形状:[3,] | \n",
"\n",
"\n",
- " \n",
- " | \n",
+ " | \n",
"
\n",
""
]
@@ -1406,8 +1383,7 @@
" 三个字符串分割,形状:[3, None] | \n",
"\n",
"\n",
- " \n",
- " | \n",
+ " | \n",
"
\n",
""
]
@@ -1505,8 +1481,7 @@
" “tf.SparseTensor”,形状:[3, 4] | \n",
"\n",
"\n",
- " \n",
- " | \n",
+ " | \n",
"
\n",
""
]
diff --git a/site/zh-cn/guide/tf_numpy.ipynb b/site/zh-cn/guide/tf_numpy.ipynb
index 01fda6d4ee..0589e9793d 100644
--- a/site/zh-cn/guide/tf_numpy.ipynb
+++ b/site/zh-cn/guide/tf_numpy.ipynb
@@ -48,9 +48,12 @@
"source": [
""
]
},
@@ -62,7 +65,7 @@
"source": [
"## 概述\n",
"\n",
- "TensorFlow 实现了一部分 [NumPy API](https://numpy.org/doc/1.16),这些 API 以 `tf.experimental.numpy` 形式提供。这样可以运行由 TensorFlow 加速的 NumPy 代码,并使用 TensorFlow 的所有 API。"
+ "TensorFlow 实现了一部分 [NumPy API](https://numpy.org/doc/stable/index.html),这些 API 以 `tf.experimental.numpy` 形式提供。这样可以运行由 TensorFlow 加速的 NumPy 代码,并使用 TensorFlow 的所有 API。"
]
},
{
@@ -134,7 +137,7 @@
"\n",
"称为 **ND Array** 的实例 `tf.experimental.numpy.ndarray` 表示放置在特定设备上的给定 `dtype` 的多维密集数组。它是 `tf.Tensor` 的别名。请查看 ND 数组类来获取有用的方法,例如 `ndarray.T`、`ndarray.reshape`、`ndarray.ravel` 等。\n",
"\n",
- "首先,创建一个 ND 数组对象,然后调用不同的方法。 "
+ "首先,创建一个 ND 数组对象,然后调用不同的方法。"
]
},
{
@@ -162,11 +165,28 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "Mub8-dvJMUr4"
+ "id": "-BOY8CGRKEhE"
},
"source": [
"### 类型提升\n",
"\n",
+ "TensorFlow 中的类型提升有 4 个选项。\n",
+ "\n",
+ "- 默认情况下,TensorFlow 会引发错误,而不是提升混合类型运算的类型。\n",
+ "- 运行 `tf.numpy.experimental_enable_numpy_behavior()` 会将 TensorFlow 切换为使用 `NumPy` 类型提升规则(如下所述)。\n",
+ "- 在 TensorFlow 2.15 之后,有两个新选项(有关详细信息,请参阅 [TF NumPy 类型提升](tf_numpy_type_promotion.ipynb)):\n",
+ " - `tf.numpy.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")` 使用 Jax 类型提升规则。\n",
+ " - `tf.numpy.experimental_enable_numpy_behavior(dtype_conversion_mode=\"safe\")` 使用 Jax 类型提升规则,但不允许某些不安全的提升。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SXskSHrX5J45"
+ },
+ "source": [
+ "#### NumPy 类型提升\n",
+ "\n",
"TensorFlow NumPy API 具有明确定义的语义,可用于将文字转换为 ND 数组,以及对 ND 数组输入执行类型提升。有关更多详细信息,请参阅 [`np.result_type`](https://numpy.org/doc/1.16/reference/generated/numpy.result_type.html)。"
]
},
@@ -192,7 +212,7 @@
" (tnp.int32, tnp.int64, tnp.float32, tnp.float64)]\n",
"for i, v1 in enumerate(values):\n",
" for v2 in values[i + 1:]:\n",
- " print(\"%s + %s => %s\" % \n",
+ " print(\"%s + %s => %s\" %\n",
" (v1.dtype.name, v2.dtype.name, (v1 + v2).dtype.name))"
]
},
@@ -921,7 +941,6 @@
"metadata": {
"accelerator": "GPU",
"colab": {
- "collapsed_sections": [],
"name": "tf_numpy.ipynb",
"toc_visible": true
},
diff --git a/site/zh-cn/guide/tf_numpy_type_promotion.ipynb b/site/zh-cn/guide/tf_numpy_type_promotion.ipynb
new file mode 100644
index 0000000000..edcdcc2e88
--- /dev/null
+++ b/site/zh-cn/guide/tf_numpy_type_promotion.ipynb
@@ -0,0 +1,1133 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZjN_IJ8mhJ-4"
+ },
+ "source": [
+ "##### Copyright 2023 The TensorFlow Authors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "sY3Ffd83hK3b"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "03Pw58e6mTHI"
+ },
+ "source": [
+ "# TF-NumPy 类型提升"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "l9nPKvxK-_pM"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uma-W5v__DYh"
+ },
+ "source": [
+ "## 文本特征向量\n",
+ "\n",
+ "TensorFlow 中的类型提升有 4 个选项。\n",
+ "\n",
+ "- 默认情况下,TensorFlow 会引发错误,而不是提升混合类型运算的类型。\n",
+ "- 运行 `tf.numpy.experimental_enable_numpy_behavior()` 会将 TensorFlow 切换为使用 [NumPy 类型提升规则](https://tensorflow.google.cn/guide/tf_numpy#type_promotion)。\n",
+ "- **本文档**介绍了 TensorFlow 2.15(或目前为 `tf-nightly`)中提供的两个新选项:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "vMvEKDFOsau7"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q tf_nightly"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a6hOFBfPsd3y"
+ },
+ "source": [
+ "**注**:`experimental_enable_numpy_behavior` 会更改所有 TensorFlow 的行为。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ob1HNwUmYR5b"
+ },
+ "source": [
+ "## 安装"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "AJR558zjAZQu"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import tensorflow as tf\n",
+ "import tensorflow.experimental.numpy as tnp\n",
+ "\n",
+ "print(\"Using TensorFlow version %s\" % tf.__version__)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "M6tacoy0DU6e"
+ },
+ "source": [
+ "### 启用新类型提升\n",
+ "\n",
+ "为了在 TF-Numpy 中使用[类似 JAX 的类型提升](https://jax.readthedocs.io/en/latest/type_promotion.html),请在为 TensorFlow 启用 NumPy 行为时指定 `'all'` 或 `'safe'` 作为数据类型转换模式。\n",
+ "\n",
+ "此新系统 (`dtype_conversion_mode=\"all\"`) 可结合、可交换,并且可以轻松控制最终的浮点数宽度(它不会自动转换为更宽的浮点数)。它确实引入了一些溢出和精度损失的风险,但 `dtype_conversion_mode=\"safe\"` 会强制您显式处理这些情况。[下一部分](#two_modes)将更详细地解释这两种模式。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "TfCyofpFDQxm"
+ },
+ "outputs": [],
+ "source": [
+ "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "sEMXK8-ZWMun"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "## 两种模式:ALL 模式与 SAFE 模式\n",
+ "\n",
+ "在新提升系统中,我们引入了两种模式:`ALL` 模式和 `SAFE` 模式。`SAFE` 模式用于减轻可能导致精度损失或位加宽的“风险”提升的担忧。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-ULvTWj_KnHU"
+ },
+ "source": [
+ "### 数据类型\n",
+ "\n",
+ "为简洁起见,我们将使用以下缩写。\n",
+ "\n",
+ "- `b` 表示 `tf.bool`\n",
+ "- `u8` 表示 `tf.uint8`\n",
+ "- `i16` 表示 `tf.int16`\n",
+ "- `i32` 表示 `tf.int32`\n",
+ "- `bf16` 表示 `tf.bfloat16`\n",
+ "- `f32` 表示 `tf.float32`\n",
+ "- `f64` 表示 `tf.float64`\n",
+ "- `i32*` 表示 Python `int` 或弱类型 `i32`\n",
+ "- `f32*` 表示 Python `float` 浮点型或弱类型 `f32`\n",
+ "- `c128*` 表示 Python `complex` 或弱类型 `c128`\n",
+ "\n",
+ "星号 (*) 表示相应的类型是“弱类型”- 此类数据类型是由系统临时推断的,可以遵从其他数据类型。[此处](#weak_tensor)更详细地解释了这个概念。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hXZxLCkuzzq3"
+ },
+ "source": [
+ "### 精度损失运算示例\n",
+ "\n",
+ "在以下示例中,`ALL` 模式下允许使用 `i32` + `f32`,但由于精度损失的风险,`SAFE` 模式下不允许使用。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Y-yeIvstWStL"
+ },
+ "outputs": [],
+ "source": [
+ "# i32 + f32 returns a f32 result in ALL mode.\n",
+ "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
+ "a = tf.constant(10, dtype = tf.int32)\n",
+ "b = tf.constant(5.0, dtype = tf.float32)\n",
+ "a + b # "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "JNNmZow2WY3G"
+ },
+ "outputs": [],
+ "source": [
+ "# This promotion is not allowed in SAFE mode.\n",
+ "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"safe\")\n",
+ "a = tf.constant(10, dtype = tf.int32)\n",
+ "b = tf.constant(5.0, dtype = tf.float32)\n",
+ "try:\n",
+ " a + b\n",
+ "except TypeError as e:\n",
+ " print(f'{type(e)}: {e}') # TypeError: explicitly specify the dtype or switch to ALL mode."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "f0x4Qhff0AKS"
+ },
+ "source": [
+ "### 位加宽运算示例\n",
+ "\n",
+ "在以下示例中,ALL 模式下允许使用 `i8` + `u32`,但由于位加宽,SAFE 模式下不允许使用,这意味着使用的位数多于输入中的位数。请注意,新的类型提升语义仅允许必要的位加宽。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Etbv-WoWzUXf"
+ },
+ "outputs": [],
+ "source": [
+ "# i8 + u32 returns an i64 result in ALL mode.\n",
+ "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
+ "a = tf.constant(10, dtype = tf.int8)\n",
+ "b = tf.constant(5, dtype = tf.uint32)\n",
+ "a + b"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yKRdvtvw0Lvt"
+ },
+ "outputs": [],
+ "source": [
+ "# This promotion is not allowed in SAFE mode.\n",
+ "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"safe\")\n",
+ "a = tf.constant(10, dtype = tf.int8)\n",
+ "b = tf.constant(5, dtype = tf.uint32)\n",
+ "try:\n",
+ " a + b\n",
+ "except TypeError as e:\n",
+ " print(f'{type(e)}: {e}') # TypeError: explicitly specify the dtype or switch to ALL mode."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yh2BwqUzH3C3"
+ },
+ "source": [
+ "## 基于点阵的系统"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HHUnfTPiYVN5"
+ },
+ "source": [
+ "### 类型提升点阵\n",
+ "\n",
+ "新的类型提升行为通过以下类型提升点阵来确定:\n",
+ "\n",
+ "![Type Promotion Lattice](https://tensorflow.org/guide/images/new_type_promotion/type_promotion_lattice.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QykluwRyDDle"
+ },
+ "source": [
+ "更具体地说,任何两种类型之间的提升是通过查找两个节点的第一个公共子节点(包括节点本身)来确定的。\n",
+ "\n",
+ "例如,在上图中,`i8` 和 `i32` 的第一个公共子节点是 `i32`,因为沿着箭头方向,这两个节点在 `i32` 处第一次相交。\n",
+ "\n",
+ "类似地,在另一个示例中,`u64` 和 `f16` 之间的结果提升类型为 `f16`。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nthziRHaDAUY"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "### 类型提升表\n",
+ "\n",
+ "按照点阵行进会生成下面的二进制提升表:\n",
+ "\n",
+ "**注**:`SAFE` 不允许高亮显示的单元格。`ALL` 模式允许全部情况。\n",
+ "\n",
+ "![Type Promotion Table](https://tensorflow.org/guide/images/new_type_promotion/type_promotion_table.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TPDt5QTkucSC"
+ },
+ "source": [
+ "## 新类型提升的优点\n",
+ "\n",
+ "我们针对新类型提升采用类似 JAX 的基于点阵的系统,它具有以下优点:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NUS_b13nue1p"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "#### 基于点阵的系统的优点\n",
+ "\n",
+ "首先,使用基于点阵的系统可以确保三个非常重要的属性:\n",
+ "\n",
+ "- 存在性:任何类型的组合都存在唯一的结果提升类型。\n",
+ "- 交换性:`a + b = b + a`\n",
+ "- 结合性:`a + (b + c) = (a + b) = c`\n",
+ "\n",
+ "这三个属性对于构建一致且可预测的类型提升语义至关重要。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Sz88hRR6uhls"
+ },
+ "source": [
+ "#### 类似 JAX 的点阵系统的优点\n",
+ "\n",
+ "类似 JAX 的点阵系统的另一个重要优点是,除了无符号整数之外,它避免了所有超出必要范围的提升。这意味着没有 64 位输入就无法获得 64 位结果。这对于加速器上的工作特别有利,因为它可以避免不必要的 64 位值,这在旧类型提升中十分常见。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rlylb7ieOVbJ"
+ },
+ "source": [
+ "不过,这需要一定的权衡:混合浮点/整数提升很容易导致精度损失。例如,在下面的示例中,`i64` + `f16` 会导致将 `i64` 提升为 `f16`。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "abqIkV02OXEF"
+ },
+ "outputs": [],
+ "source": [
+ "# The first input is promoted to f16 in ALL mode.\n",
+ "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
+ "tf.constant(1, tf.int64) + tf.constant(3.2, tf.float16) # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mYnh1gZdObfI"
+ },
+ "source": [
+ "为了缓解这种担忧,我们引入了 `SAFE` 模式,此模式会禁止这些“风险”提升。\n",
+ "\n",
+ "**注**:要详细了解构造点阵系统的设计注意事项,请参阅 [JAX 的类型提升语义设计](https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html)。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gAc7LFV0S2dP"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "## WeakTensor"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "olQ2gsFlS9BH"
+ },
+ "source": [
+ "### 概述\n",
+ "\n",
+ "*WeakTensor* 是“弱类型”的张量,类似于 [JAX 中的概念](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax)。\n",
+ "\n",
+ "`WeakTensor` 的数据类型是由系统临时推断的,并且可以遵从其他数据类型。在新类型提升中引入此概念的目的是防止 TF 值与没有用户显式指定类型的值(例如 Python 标量文字)之间的二进制运算中出现不需要的类型提升。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MYmoFIqZTFtw"
+ },
+ "source": [
+ "例如,在下面的示例中,`tf.constant(1.2)` 被视为“弱”,因为它没有特定的数据类型。因此,`tf.constant(1.2)` 遵从 `tf.constant(3.1, tf.float16)` 的类型,产生 `f16` 输出。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "eSBv_mzyTE97"
+ },
+ "outputs": [],
+ "source": [
+ "tf.constant(1.2) + tf.constant(3.1, tf.float16) # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KxuqBIFuTm5Z"
+ },
+ "source": [
+ "### WeakTensor 构造\n",
+ "\n",
+ "如果您创建张量而不指定数据类型,则会创建 WeakTensor。可以通过检查张量字符串表示末尾的弱特性来检查张量是否为“弱”张量。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7UmunnJ8Tru3"
+ },
+ "source": [
+ "**第一种情况**:使用没有用户指定数据类型的输入调用 `tf.constant` 时。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fLEtMluNTsI5"
+ },
+ "outputs": [],
+ "source": [
+ "tf.constant(5) # "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZQX6MBWHTt__"
+ },
+ "outputs": [],
+ "source": [
+ "tf.constant([5.0, 10.0, 3]) # "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ftsKSC5BTweP"
+ },
+ "outputs": [],
+ "source": [
+ "# A normal Tensor is created when dtype arg is specified.\n",
+ "tf.constant(5, tf.int32) # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RqhoRy5iTyag"
+ },
+ "source": [
+ "**第二种情况**:当没有用户指定数据类型的输入被传递到[支持 WeakTensor 的 API](#weak_tensor_apis) 中时。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DuwpgoQJTzE-"
+ },
+ "outputs": [],
+ "source": [
+ "tf.math.abs([100.0, 4.0]) # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UTcoR1xvR39k"
+ },
+ "source": [
+ "##开启新类型提升的效果\n",
+ "\n",
+ "以下是由于开启新类型提升而引起的更改的非详尽清单。\n",
+ "\n",
+ "- 提升结果更一致且可预测。\n",
+ "- 降低位加宽的风险。\n",
+ "- `tf.Tensor` 数学 dunder 方法使用新类型提升。\n",
+ "- `tf.constant` 可以返回 `WeakTensor`。\n",
+ "- 当传入一个数据类型与 `dtype` 参数不同的张量输入时,`tf.constant` 允许隐式转换。\n",
+ "- `tf.Variable` 就地运算(`assign`、`assign-add`、`assign-sub`)允许隐式转换。\n",
+ "- `tnp.array(1)` 和 `tnp.array(1.0)` 返回 32 位 WeakTensor。\n",
+ "- 将创建 `WeakTensor` 用于[支持 WeakTensor 的一元和二元 API](#weak_tensor_apis)。\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KyvonwYcsFX2"
+ },
+ "source": [
+ "### 提升结果更一致且可预测性提升\n",
+ "\n",
+ "使用[基于点阵的系统](#lattice_system_design)允许新类型提升产生一致且可预测的类型提升结果。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "q0Z1njfb7lRa"
+ },
+ "source": [
+ "#### 旧类型提升\n",
+ "\n",
+ "使用旧类型提升更改运算顺序会产生不一致的结果。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "M1Ca9v4m7z8e"
+ },
+ "outputs": [],
+ "source": [
+ "# Setup\n",
+ "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"legacy\")\n",
+ "a = np.array(1, dtype=np.int8)\n",
+ "b = tf.constant(1)\n",
+ "c = np.array(1, dtype=np.float16)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WwhTzJ-a4rTc"
+ },
+ "outputs": [],
+ "source": [
+ "# (a + b) + c throws an InvalidArgumentError.\n",
+ "try:\n",
+ " tf.add(tf.add(a, b), c)\n",
+ "except tf.errors.InvalidArgumentError as e:\n",
+ " print(f'{type(e)}: {e}') # InvalidArgumentError"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "d3qDgVYn7ezT"
+ },
+ "outputs": [],
+ "source": [
+ "# (b + a) + c returns an i32 result.\n",
+ "tf.add(tf.add(b, a), c) # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YMH1skEs7oI5"
+ },
+ "source": [
+ "#### 新类型提升\n",
+ "\n",
+ "无论顺序如何,新类型提升都会产生一致的结果。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BOHyJJ8z8uCN"
+ },
+ "outputs": [],
+ "source": [
+ "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
+ "a = np.array(1, dtype=np.int8)\n",
+ "b = tf.constant(1)\n",
+ "c = np.array(1, dtype=np.float16)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZUKU70jf7E1l"
+ },
+ "outputs": [],
+ "source": [
+ "# (a + b) + c returns a f16 result.\n",
+ "tf.add(tf.add(a, b), c) # "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "YOEycjFx7qDn"
+ },
+ "outputs": [],
+ "source": [
+ "# (b + a) + c also returns a f16 result.\n",
+ "tf.add(tf.add(b, a), c) # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FpGMkm6aJsn6"
+ },
+ "source": [
+ "### 降低位加宽的风险"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JxV2AL-U9Grg"
+ },
+ "source": [
+ "#### 旧类型提升\n",
+ "\n",
+ "旧类型提升通常会产生 64 位结果。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7L1pxyvn9MlP"
+ },
+ "outputs": [],
+ "source": [
+ "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"legacy\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zMJVFdWf4XHp"
+ },
+ "outputs": [],
+ "source": [
+ "np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50) # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fBhUH_wD9Is7"
+ },
+ "source": [
+ "#### 新类型提升\n",
+ "\n",
+ "新类型提升返回所需位数最少的结果。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "aJsj2ZyI9T9Y"
+ },
+ "outputs": [],
+ "source": [
+ "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jj0N_Plp4X9l"
+ },
+ "outputs": [],
+ "source": [
+ "np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50) # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yKUx7xe-KZ5O"
+ },
+ "source": [
+ "### tf.Tensor 数学 dunder 方法\n",
+ "\n",
+ "所有 `tf.Tensor` 数学 dunder 方法都将遵循新类型提升。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2c3icBUX4wNl"
+ },
+ "outputs": [],
+ "source": [
+ "-tf.constant(5) # "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ydJHQjid45s7"
+ },
+ "outputs": [],
+ "source": [
+ "tf.constant(5, tf.int16) - tf.constant(1, tf.float32) # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "pLbIjIvbKqcU"
+ },
+ "source": [
+ "### tf.Variable 就地运算\n",
+ "\n",
+ "`tf.Variable` 就地运算中允许隐式转换。\n",
+ "\n",
+ "**注**:任何导致数据类型与变量的原始数据类型不同的提升都是不允许的。原因是 `tf.Variable` 不能更改其数据类型。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QsXhyK1h-i5S"
+ },
+ "outputs": [],
+ "source": [
+ "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
+ "a = tf.Variable(10, tf.int32)\n",
+ "a.assign_add(tf.constant(5, tf.int16)) # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PiA4H-otLDit"
+ },
+ "source": [
+ "### tf.constant 隐式转换\n",
+ "\n",
+ "在旧类型提升中,`tf.constant` 要求输入张量与数据类型参数具有相同的数据类型。不过,在新类型提升中,我们将张量隐式转换为指定的数据类型。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ArrQ9Dj0_OR8"
+ },
+ "outputs": [],
+ "source": [
+ "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
+ "a = tf.constant(10, tf.int16)\n",
+ "tf.constant(a, tf.float32) # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WAcK_-XnLWaP"
+ },
+ "source": [
+ "### TF-NumPy 数组\n",
+ "\n",
+ "对于使用新类型提升的 Python 输入,`tnp.array` 默认为 `i32*` 和 `f32*`。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "K1pZnYNh_ahm"
+ },
+ "outputs": [],
+ "source": [
+ "tnp.array(1) # "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QoQl2PYP_fMT"
+ },
+ "outputs": [],
+ "source": [
+ "tnp.array(1.0) # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wK5DpQ3Pz3k5"
+ },
+ "source": [
+ "##输入类型推断\n",
+ "\n",
+ "下面是在新类型提升中推断不同输入类型的方式。\n",
+ "\n",
+ "- `tf.Tensor`:由于 `tf.Tensor` 具有数据类型属性,我们不做进一步的推断。\n",
+ "- NumPy 类型:包括 `np.array(1)`、`np.int16(1)` 和 `np.float` 等类型。由于 NumPy 输入也具有数据类型属性,我们将数据类型属性作为结果推断类型。请注意,NumPy 默认为 `i64` 和 `f64`。\n",
+ "- Python 标量/嵌套类型:包括 `1`、`[1, 2, 3]` 和 `(1.0, 2.0)` 等类型。\n",
+ " - Python `int` 被推断为 `i32*`。\n",
+ " - Python `float` 被推断为 `f32*`。\n",
+ " - Python `complex` 被推断为 `c128*`。\n",
+ "- 如果输入不属于上述任何类别,但具有数据类型属性,我们将数据类型属性作为结果推断类型。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "g_SPfalfSPgg"
+ },
+ "source": [
+ "# 延伸阅读\n",
+ "\n",
+ "新类型提升与 JAX-NumPy 的类型提升非常相似。如果想了解有关新类型提升和设计选择的更多详细信息,请查阅以下资源。\n",
+ "\n",
+ "- [JAX 类型提升语义](https://jax.readthedocs.io/en/latest/type_promotion.html)\n",
+ "- [JAX 类型提升语义设计](https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html)\n",
+ "- [旧 TF-NumPy 提升语义](https://tensorflow.google.cn/guide/tf_numpy#type_promotion)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Qg5xBbImT31S"
+ },
+ "source": [
+ "# 参考"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gjB0CVhVXBfW"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "## 支持 WeakTensor 的 API"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_GVbqlN9aBS2"
+ },
+ "source": [
+ "以下是支持 `WeakTensor` 的 API 列表。\n",
+ "\n",
+ "对于一元运算,这意味着如果传入没有用户指定类型的输入,它将返回 `WeakTensor`。\n",
+ "\n",
+ "对于二元运算,它将遵循[此处](#promotion_table)的提升表。它可能会也可能不会返回 `WeakTensor`,具体取决于两个输入的提升结果。\n",
+ "\n",
+ "**注**:支持所有数学运算(`+`、`-`、`*`、...)。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Gi-G68Z8WN2P"
+ },
+ "source": [
+ "- `tf.bitwise.invert`\n",
+ "- `tf.clip_by_value`\n",
+ "- `tf.debugging.check_numerics`\n",
+ "- `tf.expand_dims`\n",
+ "- `tf.identity`\n",
+ "- `tf.image.adjust_brightness`\n",
+ "- `tf.image.adjust_gamma`\n",
+ "- `tf.image.extract_patches`\n",
+ "- `tf.image.random_brightness`\n",
+ "- `tf.image.stateless_random_brightness`\n",
+ "- `tf.linalg.diag`\n",
+ "- `tf.linalg.diag_part`\n",
+ "- `tf.linalg.matmul`\n",
+ "- `tf.linalg.matrix_transpose`\n",
+ "- `tf.linalg.tensor_diag_part`\n",
+ "- `tf.linalg.trace`\n",
+ "- `tf.math.abs`\n",
+ "- `tf.math.acos`\n",
+ "- `tf.math.acosh`\n",
+ "- `tf.math.add`\n",
+ "- `tf.math.angle`\n",
+ "- `tf.math.asin`\n",
+ "- `tf.math.asinh`\n",
+ "- `tf.math.atan`\n",
+ "- `tf.math.atanh`\n",
+ "- `tf.math.ceil`\n",
+ "- `tf.math.conj`\n",
+ "- `tf.math.cos`\n",
+ "- `tf.math.cosh`\n",
+ "- `tf.math.digamma`\n",
+ "- `tf.math.divide_no_nan`\n",
+ "- `tf.math.divide`\n",
+ "- `tf.math.erf`\n",
+ "- `tf.math.erfc`\n",
+ "- `tf.math.erfcinv`\n",
+ "- `tf.math.erfinv`\n",
+ "- `tf.math.exp`\n",
+ "- `tf.math.expm1`\n",
+ "- `tf.math.floor`\n",
+ "- `tf.math.floordiv`\n",
+ "- `tf.math.floormod`\n",
+ "- `tf.math.imag`\n",
+ "- `tf.math.lgamma`\n",
+ "- `tf.math.log1p`\n",
+ "- `tf.math.log_sigmoid`\n",
+ "- `tf.math.log`\n",
+ "- `tf.math.multiply_no_nan`\n",
+ "- `tf.math.multiply`\n",
+ "- `tf.math.ndtri`\n",
+ "- `tf.math.negative`\n",
+ "- `tf.math.pow`\n",
+ "- `tf.math.real`\n",
+ "- `tf.math.real`\n",
+ "- `tf.math.reciprocal_no_nan`\n",
+ "- `tf.math.reciprocal`\n",
+ "- `tf.math.reduce_euclidean_norm`\n",
+ "- `tf.math.reduce_logsumexp`\n",
+ "- `tf.math.reduce_max`\n",
+ "- `tf.math.reduce_mean`\n",
+ "- `tf.math.reduce_min`\n",
+ "- `tf.math.reduce_prod`\n",
+ "- `tf.math.reduce_std`\n",
+ "- `tf.math.reduce_sum`\n",
+ "- `tf.math.reduce_variance`\n",
+ "- `tf.math.rint`\n",
+ "- `tf.math.round`\n",
+ "- `tf.math.rsqrt`\n",
+ "- `tf.math.scalar_mul`\n",
+ "- `tf.math.sigmoid`\n",
+ "- `tf.math.sign`\n",
+ "- `tf.math.sin`\n",
+ "- `tf.math.sinh`\n",
+ "- `tf.math.softplus`\n",
+ "- `tf.math.special.bessel_i0`\n",
+ "- `tf.math.special.bessel_i0e`\n",
+ "- `tf.math.special.bessel_i1`\n",
+ "- `tf.math.special.bessel_i1e`\n",
+ "- `tf.math.special.bessel_j0`\n",
+ "- `tf.math.special.bessel_j1`\n",
+ "- `tf.math.special.bessel_k0`\n",
+ "- `tf.math.special.bessel_k0e`\n",
+ "- `tf.math.special.bessel_k1`\n",
+ "- `tf.math.special.bessel_k1e`\n",
+ "- `tf.math.special.bessel_y0`\n",
+ "- `tf.math.special.bessel_y1`\n",
+ "- `tf.math.special.dawsn`\n",
+ "- `tf.math.special.expint`\n",
+ "- `tf.math.special.fresnel_cos`\n",
+ "- `tf.math.special.fresnel_sin`\n",
+ "- `tf.math.special.spence`\n",
+ "- `tf.math.sqrt`\n",
+ "- `tf.math.square`\n",
+ "- `tf.math.subtract`\n",
+ "- `tf.math.tan`\n",
+ "- `tf.math.tanh`\n",
+ "- `tf.nn.depth_to_space`\n",
+ "- `tf.nn.elu`\n",
+ "- `tf.nn.gelu`\n",
+ "- `tf.nn.leaky_relu`\n",
+ "- `tf.nn.log_softmax`\n",
+ "- `tf.nn.relu6`\n",
+ "- `tf.nn.relu`\n",
+ "- `tf.nn.selu`\n",
+ "- `tf.nn.softsign`\n",
+ "- `tf.nn.space_to_depth`\n",
+ "- `tf.nn.swish`\n",
+ "- `tf.ones_like`\n",
+ "- `tf.realdiv`\n",
+ "- `tf.reshape`\n",
+ "- `tf.squeeze`\n",
+ "- `tf.stop_gradient`\n",
+ "- `tf.transpose`\n",
+ "- `tf.truncatediv`\n",
+ "- `tf.truncatemod`\n",
+ "- `tf.zeros_like`\n",
+ "- `tf.experimental.numpy.abs`\n",
+ "- `tf.experimental.numpy.absolute`\n",
+ "- `tf.experimental.numpy.amax`\n",
+ "- `tf.experimental.numpy.amin`\n",
+ "- `tf.experimental.numpy.angle`\n",
+ "- `tf.experimental.numpy.arange`\n",
+ "- `tf.experimental.numpy.arccos`\n",
+ "- `tf.experimental.numpy.arccosh`\n",
+ "- `tf.experimental.numpy.arcsin`\n",
+ "- `tf.experimental.numpy.arcsinh`\n",
+ "- `tf.experimental.numpy.arctan`\n",
+ "- `tf.experimental.numpy.arctanh`\n",
+ "- `tf.experimental.numpy.around`\n",
+ "- `tf.experimental.numpy.array`\n",
+ "- `tf.experimental.numpy.asanyarray`\n",
+ "- `tf.experimental.numpy.asarray`\n",
+ "- `tf.experimental.numpy.ascontiguousarray`\n",
+ "- `tf.experimental.numpy.average`\n",
+ "- `tf.experimental.numpy.bitwise_not`\n",
+ "- `tf.experimental.numpy.cbrt`\n",
+ "- `tf.experimental.numpy.ceil`\n",
+ "- `tf.experimental.numpy.conj`\n",
+ "- `tf.experimental.numpy.conjugate`\n",
+ "- `tf.experimental.numpy.copy`\n",
+ "- `tf.experimental.numpy.cos`\n",
+ "- `tf.experimental.numpy.cosh`\n",
+ "- `tf.experimental.numpy.cumprod`\n",
+ "- `tf.experimental.numpy.cumsum`\n",
+ "- `tf.experimental.numpy.deg2rad`\n",
+ "- `tf.experimental.numpy.diag`\n",
+ "- `tf.experimental.numpy.diagflat`\n",
+ "- `tf.experimental.numpy.diagonal`\n",
+ "- `tf.experimental.numpy.diff`\n",
+ "- `tf.experimental.numpy.empty_like`\n",
+ "- `tf.experimental.numpy.exp2`\n",
+ "- `tf.experimental.numpy.exp`\n",
+ "- `tf.experimental.numpy.expand_dims`\n",
+ "- `tf.experimental.numpy.expm1`\n",
+ "- `tf.experimental.numpy.fabs`\n",
+ "- `tf.experimental.numpy.fix`\n",
+ "- `tf.experimental.numpy.flatten`\n",
+ "- `tf.experimental.numpy.flip`\n",
+ "- `tf.experimental.numpy.fliplr`\n",
+ "- `tf.experimental.numpy.flipud`\n",
+ "- `tf.experimental.numpy.floor`\n",
+ "- `tf.experimental.numpy.full_like`\n",
+ "- `tf.experimental.numpy.imag`\n",
+ "- `tf.experimental.numpy.log10`\n",
+ "- `tf.experimental.numpy.log1p`\n",
+ "- `tf.experimental.numpy.log2`\n",
+ "- `tf.experimental.numpy.log`\n",
+ "- `tf.experimental.numpy.max`\n",
+ "- `tf.experimental.numpy.mean`\n",
+ "- `tf.experimental.numpy.min`\n",
+ "- `tf.experimental.numpy.moveaxis`\n",
+ "- `tf.experimental.numpy.nanmean`\n",
+ "- `tf.experimental.numpy.negative`\n",
+ "- `tf.experimental.numpy.ones_like`\n",
+ "- `tf.experimental.numpy.positive`\n",
+ "- `tf.experimental.numpy.prod`\n",
+ "- `tf.experimental.numpy.rad2deg`\n",
+ "- `tf.experimental.numpy.ravel`\n",
+ "- `tf.experimental.numpy.real`\n",
+ "- `tf.experimental.numpy.reciprocal`\n",
+ "- `tf.experimental.numpy.repeat`\n",
+ "- `tf.experimental.numpy.reshape`\n",
+ "- `tf.experimental.numpy.rot90`\n",
+ "- `tf.experimental.numpy.round`\n",
+ "- `tf.experimental.numpy.signbit`\n",
+ "- `tf.experimental.numpy.sin`\n",
+ "- `tf.experimental.numpy.sinc`\n",
+ "- `tf.experimental.numpy.sinh`\n",
+ "- `tf.experimental.numpy.sort`\n",
+ "- `tf.experimental.numpy.sqrt`\n",
+ "- `tf.experimental.numpy.square`\n",
+ "- `tf.experimental.numpy.squeeze`\n",
+ "- `tf.experimental.numpy.std`\n",
+ "- `tf.experimental.numpy.sum`\n",
+ "- `tf.experimental.numpy.swapaxes`\n",
+ "- `tf.experimental.numpy.tan`\n",
+ "- `tf.experimental.numpy.tanh`\n",
+ "- `tf.experimental.numpy.trace`\n",
+ "- `tf.experimental.numpy.transpose`\n",
+ "- `tf.experimental.numpy.triu`\n",
+ "- `tf.experimental.numpy.vander`\n",
+ "- `tf.experimental.numpy.var`\n",
+ "- `tf.experimental.numpy.zeros_like`"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "tf_numpy_type_promotion.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/site/zh-cn/guide/tpu.ipynb b/site/zh-cn/guide/tpu.ipynb
index a4f465bab2..31bd25f531 100644
--- a/site/zh-cn/guide/tpu.ipynb
+++ b/site/zh-cn/guide/tpu.ipynb
@@ -41,8 +41,9 @@
"\n",
""
]
@@ -64,7 +65,7 @@
"id": "ek5Hop74NVKm"
},
"source": [
- "## 设置"
+ "## 安装"
]
},
{
@@ -246,13 +247,32 @@
"outputs": [],
"source": [
"def create_model():\n",
+ " regularizer = tf.keras.regularizers.L2(1e-5)\n",
" return tf.keras.Sequential(\n",
- " [tf.keras.layers.Conv2D(256, 3, activation='relu', input_shape=(28, 28, 1)),\n",
- " tf.keras.layers.Conv2D(256, 3, activation='relu'),\n",
+ " [tf.keras.layers.Conv2D(256, 3, input_shape=(28, 28, 1),\n",
+ " activation='relu',\n",
+ " kernel_regularizer=regularizer),\n",
+ " tf.keras.layers.Conv2D(256, 3,\n",
+ " activation='relu',\n",
+ " kernel_regularizer=regularizer),\n",
" tf.keras.layers.Flatten(),\n",
- " tf.keras.layers.Dense(256, activation='relu'),\n",
- " tf.keras.layers.Dense(128, activation='relu'),\n",
- " tf.keras.layers.Dense(10)])"
+ " tf.keras.layers.Dense(256,\n",
+ " activation='relu',\n",
+ " kernel_regularizer=regularizer),\n",
+ " tf.keras.layers.Dense(128,\n",
+ " activation='relu',\n",
+ " kernel_regularizer=regularizer),\n",
+ " tf.keras.layers.Dense(10,\n",
+ " kernel_regularizer=regularizer)])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "h-2qaXgfyONQ"
+ },
+ "source": [
+ "此模型将 L2 正则化项放在每层的权重上,以便下面的自定义训练循环可以显示如何从 `Model.losses` 中选取它们。"
]
},
{
@@ -434,9 +454,13 @@
" images, labels = inputs\n",
" with tf.GradientTape() as tape:\n",
" logits = model(images, training=True)\n",
- " loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
+ " per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
" labels, logits, from_logits=True)\n",
- " loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)\n",
+ " loss = tf.nn.compute_average_loss(per_example_loss)\n",
+ " model_losses = model.losses\n",
+ " if model_losses:\n",
+ " loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))\n",
+ "\n",
" grads = tape.gradient(loss, model.trainable_variables)\n",
" optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))\n",
" training_loss.update_state(loss * strategy.num_replicas_in_sync)\n",
@@ -470,7 +494,7 @@
"\n",
" for step in range(steps_per_epoch):\n",
" train_step(train_iterator)\n",
- " print('Current step: {}, training loss: {}, accuracy: {}%'.format(\n",
+ " print('Current step: {}, training loss: {}, training accuracy: {}%'.format(\n",
" optimizer.iterations.numpy(),\n",
" round(float(training_loss.result()), 4),\n",
" round(float(training_accuracy.result()) * 100, 2)))\n",
@@ -508,9 +532,12 @@
" images, labels = inputs\n",
" with tf.GradientTape() as tape:\n",
" logits = model(images, training=True)\n",
- " loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
+ " per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
" labels, logits, from_logits=True)\n",
- " loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)\n",
+ " loss = tf.nn.compute_average_loss(per_example_loss)\n",
+ " model_losses = model.losses\n",
+ " if model_losses:\n",
+ " loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))\n",
" grads = tape.gradient(loss, model.trainable_variables)\n",
" optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))\n",
" training_loss.update_state(loss * strategy.num_replicas_in_sync)\n",
@@ -523,7 +550,7 @@
"# retraced if the value changes.\n",
"train_multiple_steps(train_iterator, tf.convert_to_tensor(steps_per_epoch))\n",
"\n",
- "print('Current step: {}, training loss: {}, accuracy: {}%'.format(\n",
+ "print('Current step: {}, training loss: {}, training accuracy: {}%'.format(\n",
" optimizer.iterations.numpy(),\n",
" round(float(training_loss.result()), 4),\n",
" round(float(training_accuracy.result()) * 100, 2)))"
diff --git a/site/zh-cn/guide/versions.md b/site/zh-cn/guide/versions.md
index 7aebf9c5f9..72fa06e4aa 100644
--- a/site/zh-cn/guide/versions.md
+++ b/site/zh-cn/guide/versions.md
@@ -29,7 +29,14 @@ TensorFlow 的公开 API 遵循语义化版本控制 2.0 ([semver](http://semver
- 兼容性 API(在 Python 中,为 `tf.compat` 模块)。在主要版本中,我们可能会发布实用工具和其他端点来帮助用户过渡到新的主要版本。这些 API 符号已弃用且不受支持(即,我们不会添加任何功能,除了修复一些漏洞外,也不会修复错误),但它们在我们的兼容性保证范围内。
-- [C API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h)。
+- TensorFlow C API:
+
+ - [tensorflow/c/c_api.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h)
+
+- TensorFlow Lite C API:
+
+ - [tensorflow/lite/c/c_api.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/c_api.h)
+ - [tensorflow/lite/c/c_api_types.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/c_api_types.h)
- 下列协议缓冲区文件:
@@ -59,7 +66,7 @@ TensorFlow 的某些部分可能随时以向后不兼容的方式更改。包括
- **其他语言**:Python 和 C 以外的其他语言中的 TensorFlow API,例如:
- [C++](../install/lang_c.ipynb)(通过 [`tensorflow/cc`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/cc) 中的头文件公开)
- - [Java](../install/lang_java_legacy.md)
+ - [Java](../install/lang_java_legacy.md),
- [Go](https://github.com/tensorflow/build/blob/master/golang_install_guide/README.md)
- [JavaScript](https://tensorflow.google.cn/js)
diff --git a/site/zh-cn/hub/common_saved_model_apis/images.md b/site/zh-cn/hub/common_saved_model_apis/images.md
index c35a50af52..24db4a5cdd 100644
--- a/site/zh-cn/hub/common_saved_model_apis/images.md
+++ b/site/zh-cn/hub/common_saved_model_apis/images.md
@@ -1,5 +1,3 @@
-
-
# 图像任务的通用 SavedModel API
本页面介绍用于图像相关任务的 [TF2 SavedModel](../tf2_saved_model.md) 应当如何实现[可重用的 SavedModel API](../reusable_saved_models.md)。(这会替换现已弃用的 [TF1 Hub 格式](../tf1_hub_module)的[通用图像签名](../common_signatures/images.md)。)
@@ -46,8 +44,7 @@ features = hub.KerasLayer("path/to/model")(images)
图像特征向量的可重用 SavedModel 在以下各项中使用:
-- Colab 教程[重新训练图像分类器](https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/tf2_image_retraining.ipynb);
-- 命令行工具 [make_image_classifier](https://github.com/tensorflow/hub/tree/master/tensorflow_hub/tools/make_image_classifier)。
+- Colab 教程[重新训练图像分类器](https://colab.research.google.com/github/tensorflow/docs/blob/master/g3doc/en/hub/tutorials/tf2_image_retraining.ipynb);
diff --git a/site/zh-cn/hub/common_saved_model_apis/text.md b/site/zh-cn/hub/common_saved_model_apis/text.md
index c617816313..f2b7d2cbdf 100644
--- a/site/zh-cn/hub/common_saved_model_apis/text.md
+++ b/site/zh-cn/hub/common_saved_model_apis/text.md
@@ -1,5 +1,3 @@
-
-
# 文本任务的通用 SavedModel API
本页面介绍用于文本相关任务的 [TF2 SavedModel](../tf2_saved_model.md) 应当如何实现[可重用的 SavedModel API](../reusable_saved_models.md)。(这会替换现已弃用的 [TF1 Hub 格式](../tf1_hub_module)的[通用文本签名](../common_signatures/text.md)。)
@@ -64,7 +62,7 @@ embeddings = hub.KerasLayer("path/to/model", trainable=...)(text_input)
### 示例
-- Colab 教程[影评文本分类](https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/tf2_text_classification.ipynb)。
+- Colab 教程[电影评论文本分类](https://colab.research.google.com/github/tensorflow/docs/blob/master/g3doc/en/hub/tutorials/tf2_text_classification.ipynb)。
diff --git a/site/zh-cn/hub/common_signatures/text.md b/site/zh-cn/hub/common_signatures/text.md
index 31249af6db..f5eec9f8d2 100644
--- a/site/zh-cn/hub/common_signatures/text.md
+++ b/site/zh-cn/hub/common_signatures/text.md
@@ -1,5 +1,3 @@
-
-
# 文本的常用签名
本页面介绍应由 [TF1 Hub 格式](../tf1_hub_module.md)的模块为接受文本输入的任务实现的常用签名。(有关 [TF2 SavedModel 格式](../tf2_saved_model.md),请参阅具有类似功能的 [SavedModel API](../common_saved_model_apis/text.md)。)
diff --git a/site/zh-cn/hub/tutorials/tf2_text_classification.ipynb b/site/zh-cn/hub/tutorials/tf2_text_classification.ipynb
new file mode 100644
index 0000000000..7cca1d36bf
--- /dev/null
+++ b/site/zh-cn/hub/tutorials/tf2_text_classification.ipynb
@@ -0,0 +1,565 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Ic4_occAAiAT"
+ },
+ "source": [
+ "##### Copyright 2019 The TensorFlow Hub Authors.\n",
+ "\n",
+ "Licensed under the Apache License, Version 2.0 (the \"License\");"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "both",
+ "id": "ioaprt5q5US7"
+ },
+ "outputs": [],
+ "source": [
+ "# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.\n",
+ "#\n",
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License.\n",
+ "# =============================================================================="
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "yCl0eTNH5RS3"
+ },
+ "outputs": [],
+ "source": [
+ "#@title MIT License\n",
+ "#\n",
+ "# Copyright (c) 2017 François Chollet # IGNORE_COPYRIGHT: cleared by OSS licensing\n",
+ "#\n",
+ "# Permission is hereby granted, free of charge, to any person obtaining a\n",
+ "# copy of this software and associated documentation files (the \"Software\"),\n",
+ "# to deal in the Software without restriction, including without limitation\n",
+ "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n",
+ "# and/or sell copies of the Software, and to permit persons to whom the\n",
+ "# Software is furnished to do so, subject to the following conditions:\n",
+ "#\n",
+ "# The above copyright notice and this permission notice shall be included in\n",
+ "# all copies or substantial portions of the Software.\n",
+ "#\n",
+ "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
+ "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
+ "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n",
+ "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
+ "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n",
+ "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n",
+ "# DEALINGS IN THE SOFTWARE."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ItXfxkxvosLH"
+ },
+ "source": [
+ "# 电影评论文本分类"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MfBg1C5NB3X0"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Eg62Pmz3o83v"
+ },
+ "source": [
+ "此笔记本利用评论文本将电影评论分类为*正面*或*负面*评价。这是一个*二元*(或二类)分类示例,也是一个重要且应用广泛的机器学习问题。\n",
+ "\n",
+ "我们将使用包含 [Internet Movie Database](https://tensorflow.google.cn/api_docs/python/tf/keras/datasets/imdb) 中的 50,000 条电影评论文本的 [IMDB 数据集](https://www.imdb.com/)。先将这些评论分为两组,其中 25,000 条用于训练,另外 25,000 条用于测试。训练组和测试组是*均衡的*,也就是说其中包含相等数量的正面评价和负面评价。\n",
+ "\n",
+ "此笔记本在 TensorFlow 和 [TensorFlow Hub](https://tensorflow.google.cn/api_docs/python/tf/keras)(一个用于迁移学习的库和平台)中使用高级 API [tf.keras](https://tensorflow.google.cn/hub) 来构建和训练模型。有关使用 `tf.keras` 的更高级文本分类教程,请参阅 [MLCC 文本分类指南](https://developers.google.com/machine-learning/guides/text-classification/)。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qrk8NjzhSBh-"
+ },
+ "source": [
+ "### 更多模型\n",
+ "\n",
+ "[这里](https://tfhub.dev/s?module-type=text-embedding)可以找到更具表现力或性能的模型,您可以使用这些模型来生成文本嵌入向量。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Q4DN769E2O_R"
+ },
+ "source": [
+ "## 安装"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2ew7HTbPpCJH"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "import tensorflow as tf\n",
+ "import tensorflow_hub as hub\n",
+ "import tensorflow_datasets as tfds\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "print(\"Version: \", tf.__version__)\n",
+ "print(\"Eager mode: \", tf.executing_eagerly())\n",
+ "print(\"Hub version: \", hub.__version__)\n",
+ "print(\"GPU is\", \"available\" if tf.config.list_physical_devices('GPU') else \"NOT AVAILABLE\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iAsKG535pHep"
+ },
+ "source": [
+ "## 下载 IMDB 数据集\n",
+ "\n",
+ "[TensorFlow 数据集](https://github.com/tensorflow/datasets)上提供了 IMDB 数据集。以下代码可将 IMDB 数据集下载到您的计算机(或 Colab 运行时)上:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zXXx5Oc3pOmN"
+ },
+ "outputs": [],
+ "source": [
+ "train_data, test_data = tfds.load(name=\"imdb_reviews\", split=[\"train\", \"test\"], \n",
+ " batch_size=-1, as_supervised=True)\n",
+ "\n",
+ "train_examples, train_labels = tfds.as_numpy(train_data)\n",
+ "test_examples, test_labels = tfds.as_numpy(test_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "l50X3GfjpU4r"
+ },
+ "source": [
+ "## 探索数据\n",
+ "\n",
+ "我们花一点时间来了解数据的格式。每个样本都是一个代表电影评论的句子和一个相应的标签。句子未经过任何预处理。标签是一个整数值(0 或 1),其中 0 表示负面评价,1 表示正面评价。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "y8qCnve_-lkO"
+ },
+ "outputs": [],
+ "source": [
+ "print(\"Training entries: {}, test entries: {}\".format(len(train_examples), len(test_examples)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RnKvHWW4-lkW"
+ },
+ "source": [
+ "我们打印前 10 个样本。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QtTS4kpEpjbi"
+ },
+ "outputs": [],
+ "source": [
+ "train_examples[:10]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "IFtaCHTdc-GY"
+ },
+ "source": [
+ "我们再打印前 10 个标签。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "tvAjVXOWc6Mj"
+ },
+ "outputs": [],
+ "source": [
+ "train_labels[:10]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LLC02j2g-llC"
+ },
+ "source": [
+ "## 构建模型\n",
+ "\n",
+ "神经网络是通过堆叠层创建的,这需要确定三个主要架构决策:\n",
+ "\n",
+ "- 如何表示文本?\n",
+ "- 在模型中使用多少个层?\n",
+ "- 为每个层使用多少个*隐藏神经元*?\n",
+ "\n",
+ "在本例中,输入数据由句子组成。要预测的标签要么是 0,要么是 1。\n",
+ "\n",
+ "表示文本的一种方法是将句子转换为嵌入向量。我们可以将预训练的文本嵌入向量作为第一层,这有两个优势:\n",
+ "\n",
+ "- 不必担心文本预处理。\n",
+ "- 可以从迁移学习获益。\n",
+ "\n",
+ "对于本示例,我们将使用 [TensorFlow Hub](https://tensorflow.google.cn/hub) 中名为 [google/nnlm-en-dim50/2](https://tfhub.dev/google/nnlm-en-dim50/2) 的模型。\n",
+ "\n",
+ "为了学习本教程,还要测试另外两个模型:\n",
+ "\n",
+ "- [google/nnlm-en-dim50-with-normalization/2](https://tfhub.dev/google/nnlm-en-dim50-with-normalization/2) - 与 [google/nnlm-en-dim50/2](https://tfhub.dev/google/nnlm-en-dim50/2) 相同,但进行了更多文本标准化以移除标点。这样有助于更好地覆盖您的输入文本上词例的词汇内嵌入向量。\n",
+ "- [google/nnlm-en-dim128-with-normalization/2](https://tfhub.dev/google/nnlm-en-dim128-with-normalization/2) - 更大的模型,嵌入向量维度为 128,而不是 50。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "In2nDpTLkgKa"
+ },
+ "source": [
+ "我们先创建一个使用 TensorFlow Hub 模型嵌入语句的 Keras 层,并使用几个输入样本试试效果。请注意,产生的嵌入向量的输出形状是预期的:`(num_examples, embedding_dimension)`。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_NUbzVeYkgcO"
+ },
+ "outputs": [],
+ "source": [
+ "model = \"https://tfhub.dev/google/nnlm-en-dim50/2\"\n",
+ "hub_layer = hub.KerasLayer(model, input_shape=[], dtype=tf.string, trainable=True)\n",
+ "hub_layer(train_examples[:3])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dfSbV6igl1EH"
+ },
+ "source": [
+ "现在构建整个模型:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xpKOoWgu-llD"
+ },
+ "outputs": [],
+ "source": [
+ "model = tf.keras.Sequential()\n",
+ "model.add(hub_layer)\n",
+ "model.add(tf.keras.layers.Dense(16, activation='relu'))\n",
+ "model.add(tf.keras.layers.Dense(1))\n",
+ "\n",
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6PbKQ6mucuKL"
+ },
+ "source": [
+ "按顺序堆叠层以构建分类器:\n",
+ "\n",
+ "1. 第一层是 TensorFlow Hub 层。该层使用预训练的 SaveModel 将句子映射到其嵌入向量。我们使用的模型 ([google/nnlm-en-dim50/2](https://tfhub.dev/google/nnlm-en-dim50/2)) 可将句子拆分为词例,嵌入每个词例,然后组合嵌入向量。生成的维度是:`(num_examples, embedding_dimension)`。\n",
+ "2. 此定长输出向量通过一个有 16 个隐藏单元的全连接 (`Dense`) 层传输。\n",
+ "3. 最后一层与单个输出节点密集连接。这会输出 logits:真类的对数几率(根据模型而定)。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0XMwnDOp-llH"
+ },
+ "source": [
+ "### 隐藏神经元\n",
+ "\n",
+ "上述模型在输入和输出之间有两个中间(或称“隐藏”)层。输出(单元、节点或神经元)的数量是层的表示空间的维度。换言之,即网络学习内部表示时允许的自由度。\n",
+ "\n",
+ "模型的隐藏单元越多(更高维度的表示空间)和/或层越多,则网络可以学习的表示越复杂。但是,这会导致网络的计算开销增加,并且可能导致学习不需要的模式——提高在训练数据(而不是测试数据)上的性能的模式。这就叫*过拟合*,我们稍后将进行探讨。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "L4EqVWg4-llM"
+ },
+ "source": [
+ "### 损失函数和优化器\n",
+ "\n",
+ "模型训练需要一个损失函数和一个优化器。由于这是二元分类问题,并且模型输出概率(具有 Sigmoid 激活的单一单元层),我们将使用 `binary_crossentropy` 损失函数。\n",
+ "\n",
+ "这并非损失函数的唯一选择,比如,您还可以选择 `mean_squared_error`。但是,一般来说,`binary_crossentropy` 更适合处理概率问题,它可以测量概率分布之间的“距离”,或者在我们的用例中,是指真实分布与预测值之间的差距。\n",
+ "\n",
+ "稍后我们研究回归问题时(比如说,预测一套房子的价格),我们将看到如何使用另一个称为均方误差的损失函数。\n",
+ "\n",
+ "现在,配置模型以使用优化器和损失函数:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Mr0GP-cQ-llN"
+ },
+ "outputs": [],
+ "source": [
+ "model.compile(optimizer='adam',\n",
+ " loss=tf.losses.BinaryCrossentropy(from_logits=True),\n",
+ " metrics=[tf.metrics.BinaryAccuracy(threshold=0.0, name='accuracy')])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hCWYwkug-llQ"
+ },
+ "source": [
+ "## 创建验证集\n",
+ "\n",
+ "训练时,我们希望检验该模型在未见过的数据上的准确率。为此,需要将原始训练数据中的 10,000 个样本分离出来,创建一个*验证集*。(为何现在不使用测试集?因为我们的目标是仅使用训练数据开发和调整模型,然后只使用一次测试数据来评估准确率)。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-NpcXY9--llS"
+ },
+ "outputs": [],
+ "source": [
+ "x_val = train_examples[:10000]\n",
+ "partial_x_train = train_examples[10000:]\n",
+ "\n",
+ "y_val = train_labels[:10000]\n",
+ "partial_y_train = train_labels[10000:]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "35jv_fzP-llU"
+ },
+ "source": [
+ "## 训练模型\n",
+ "\n",
+ "使用包含 512 个样本的 mini-batch 对模型进行 40 个周期的训练,也就是在 `x_train` 和 `y_train` 张量中对所有样本进行 40 次迭代。在训练时,监测模型在验证集的 10,000 个样本上的损失和准确率:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "tXSGrjWZ-llW"
+ },
+ "outputs": [],
+ "source": [
+ "history = model.fit(partial_x_train,\n",
+ " partial_y_train,\n",
+ " epochs=40,\n",
+ " batch_size=512,\n",
+ " validation_data=(x_val, y_val),\n",
+ " verbose=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9EEGuDVuzb5r"
+ },
+ "source": [
+ "## 评估模型\n",
+ "\n",
+ "我们来看一下模型的性能如何。将返回两个值。损失值(一个表示误差的数字,值越低越好)与准确率。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zOMKywn4zReN"
+ },
+ "outputs": [],
+ "source": [
+ "results = model.evaluate(test_examples, test_labels)\n",
+ "\n",
+ "print(results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "z1iEXVTR0Z2t"
+ },
+ "source": [
+ "这种相当简单的方法可以达到 87% 的准确率。对于更高级的方法,模型的准确率应该会接近 95%。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5KggXVeL-llZ"
+ },
+ "source": [
+ "## 创建准确率和损失随时间变化的图表\n",
+ "\n",
+ "`model.fit()` 会返回包含一个字典的 `History` 对象。该字典包含训练过程中产生的所有信息:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VcvSXvhp-llb"
+ },
+ "outputs": [],
+ "source": [
+ "history_dict = history.history\n",
+ "history_dict.keys()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nRKsqL40-lle"
+ },
+ "source": [
+ "其中有四个条目:每个条目代表训练和验证过程中的一项监测指标。我们可以使用这些指标来绘制用于比较的训练和验证图表,以及训练和验证准确率图表:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nGoYf2Js-lle"
+ },
+ "outputs": [],
+ "source": [
+ "acc = history_dict['accuracy']\n",
+ "val_acc = history_dict['val_accuracy']\n",
+ "loss = history_dict['loss']\n",
+ "val_loss = history_dict['val_loss']\n",
+ "\n",
+ "epochs = range(1, len(acc) + 1)\n",
+ "\n",
+ "# \"bo\" is for \"blue dot\"\n",
+ "plt.plot(epochs, loss, 'bo', label='Training loss')\n",
+ "# b is for \"solid blue line\"\n",
+ "plt.plot(epochs, val_loss, 'b', label='Validation loss')\n",
+ "plt.title('Training and validation loss')\n",
+ "plt.xlabel('Epochs')\n",
+ "plt.ylabel('Loss')\n",
+ "plt.legend()\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "6hXx-xOv-llh"
+ },
+ "outputs": [],
+ "source": [
+ "plt.clf() # clear figure\n",
+ "\n",
+ "plt.plot(epochs, acc, 'bo', label='Training acc')\n",
+ "plt.plot(epochs, val_acc, 'b', label='Validation acc')\n",
+ "plt.title('Training and validation accuracy')\n",
+ "plt.xlabel('Epochs')\n",
+ "plt.ylabel('Accuracy')\n",
+ "plt.legend()\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "oFEmZ5zq-llk"
+ },
+ "source": [
+ "在该图表中,虚线代表训练损失和准确率,实线代表验证损失和准确率。\n",
+ "\n",
+ "请注意,训练损失会逐周期*下降*,而训练准确率则逐周期*上升*。使用梯度下降优化时,这是预期结果,它应该在每次迭代中最大限度减少所需的数量。\n",
+ "\n",
+ "但是,对验证损失和准确率来说则不然——它们似乎会在经过 20 个周期后达到顶点。这是过拟合的一个例子:模型在训练数据上的表现要好于在之前从未见过的数据上的表现。经过这一点之后,模型会过度优化和学习*特定*于训练数据的表示,但无法*泛化*到测试数据。\n",
+ "\n",
+ "对于这种特殊情况,我们只需在经过 20 个左右的周期后停止训练即可防止过拟合。稍后您将看到如何使用回调自动执行该操作。"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "tf2_text_classification.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
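上文关于过拟合的讨论提到,可以使用回调在验证指标不再改善时自动停止训练。下面是一个最小草图(假设沿用本教程中的 `model`、`partial_x_train`、`x_val` 等名称),使用 `tf.keras.callbacks.EarlyStopping` 监控验证损失:

```python
import tensorflow as tf

# Stop training once the validation loss has not improved for 3 epochs,
# and restore the weights from the best epoch seen so far.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=3, restore_best_weights=True)

history = model.fit(partial_x_train,
                    partial_y_train,
                    epochs=40,
                    batch_size=512,
                    validation_data=(x_val, y_val),
                    callbacks=[early_stopping],
                    verbose=1)
```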
diff --git a/site/zh-cn/io/tutorials/audio.ipynb b/site/zh-cn/io/tutorials/audio.ipynb
index c05655c7b0..c5b7e8ca13 100644
--- a/site/zh-cn/io/tutorials/audio.ipynb
+++ b/site/zh-cn/io/tutorials/audio.ipynb
@@ -48,11 +48,10 @@
"source": [
""
]
},
diff --git a/site/zh-cn/io/tutorials/avro.ipynb b/site/zh-cn/io/tutorials/avro.ipynb
new file mode 100644
index 0000000000..6f6c4a6f3c
--- /dev/null
+++ b/site/zh-cn/io/tutorials/avro.ipynb
@@ -0,0 +1,562 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Tce3stUlHN0L"
+ },
+ "source": [
+ "##### Copyright 2020 The TensorFlow IO Authors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "tuOe1ymfHZPu"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qFdPvlXBOdUN"
+ },
+ "source": [
+ "# Avro 数据集 API"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MfBg1C5NB3X0"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xHxb-dlhMIzW"
+ },
+ "source": [
+ "## 文本特征向量\n",
+ "\n",
+ "Avro 数据集 API 的目标是将 Avro 格式的数据作为 TensorFlow 数据集原生加载到 TensorFlow 中。Avro 是一个类似于 Protocol Buffers 的数据序列化系统。它广泛用于 Apache Hadoop,可以提供持久数据的序列化格式和 Hadoop 节点之间通信的有线格式。Avro 数据是一种面向行的压缩二进制数据格式。它依赖于存储为单独 JSON 文件的架构。有关 Avro 格式和架构声明的规范,请参阅官方手册。\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MUXex9ctTuDB"
+ },
+ "source": [
+ "## 安装软件包\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "upgCc3gXybsA"
+ },
+ "source": [
+ "### 安装所需的 tensorflow-io 软件包"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "uUDYyMZRfkX4"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install tensorflow-io"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gjrZNJQRJP-U"
+ },
+ "source": [
+ "### 导入软件包"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "m6KXZuTBWgRm"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "import tensorflow_io as tfio\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eCgO11GTJaTj"
+ },
+ "source": [
+ "### 验证 tf 和 tfio 导入"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "dX74RKfZ_TdF"
+ },
+ "outputs": [],
+ "source": [
+ "print(\"tensorflow-io version: {}\".format(tfio.__version__))\n",
+ "print(\"tensorflow version: {}\".format(tf.__version__))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "J0ZKhA6s0Pjp"
+ },
+ "source": [
+ "## 用法"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4CfKVmCvwcL7"
+ },
+ "source": [
+ "### 探索数据集\n",
+ "\n",
+ "为了实现本教程的目的,我们来下载示例 Avro 数据集。\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "IGnbXuVnSo8T"
+ },
+ "source": [
+ "下载示例 Avro 文件:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Tu01THzWcE-J"
+ },
+ "outputs": [],
+ "source": [
+ "!curl -OL https://github.com/tensorflow/io/raw/master/docs/tutorials/avro/train.avro\n",
+ "!ls -l train.avro"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jJzE6lMwhY7l"
+ },
+ "source": [
+ "下载示例 Avro 文件的相应架构文件:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Cpxa6yhLhY7l"
+ },
+ "outputs": [],
+ "source": [
+ "!curl -OL https://github.com/tensorflow/io/raw/master/docs/tutorials/avro/train.avsc\n",
+ "!ls -l train.avsc"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "z9GCyPWNuOm7"
+ },
+ "source": [
+ "在上面的示例中,基于 MNIST 数据集创建了一个测试 Avro 数据集。TFRecord 格式的原始 MNIST 数据集从 TF 命名数据集生成。但是,作为演示数据集,MNIST 数据集过大。为简单起见,我们修剪了大部分内容,只保留前几条记录。此外,对原始 MNIST 数据集中的 `image` 字段进行了额外的修剪,并将其映射到 Avro 中的 `features` 字段。因此,Avro 文件 `train.avro` 有 4 条记录,每条记录有 3 个字段,分别为:`features`(整数的数组)、`label`(整数或 null 的数组)和 `dataType`(枚举)。要查看解码的 `train.avro`(请注意,原始 Avro 数据文件非人类可读,因为 Avro 是压缩格式),请执行以下操作:\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "upgCc3gXybsB"
+ },
+ "source": [
+ "安装读取 Avro 文件所需的包:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nS3eTBvjt-O4"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install avro\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "m7XR0agdhY7n"
+ },
+ "source": [
+ "要以人类可读的格式读取和打印 Avro 文件,请运行以下代码:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nS3eTBvjt-O5"
+ },
+ "outputs": [],
+ "source": [
+ "from avro.io import DatumReader\n",
+ "from avro.datafile import DataFileReader\n",
+ "\n",
+ "import json\n",
+ "\n",
+ "def print_avro(avro_file, max_record_num=None):\n",
+ " if max_record_num is not None and max_record_num <= 0:\n",
+ " return\n",
+ "\n",
+ " with open(avro_file, 'rb') as avro_handler:\n",
+ " reader = DataFileReader(avro_handler, DatumReader())\n",
+ " record_count = 0\n",
+ " for record in reader:\n",
+ " record_count = record_count+1\n",
+ " print(record)\n",
+ " if max_record_num is not None and record_count == max_record_num:\n",
+ " break\n",
+ "\n",
+ "print_avro(avro_file='train.avro')\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qKgUPm6JhY7n"
+ },
+ "source": [
+ "由 `train.avsc` 表示的 `train.avro` 的架构是一个 JSON 格式的文件。查看`train.avsc`:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "D-95aom1hY7o"
+ },
+ "outputs": [],
+ "source": [
+ "def print_schema(avro_schema_file):\n",
+ " with open(avro_schema_file, 'r') as handle:\n",
+ " parsed = json.load(handle)\n",
+ " print(json.dumps(parsed, indent=4, sort_keys=True))\n",
+ "\n",
+ "print_schema('train.avsc')\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "21szKFY1hY7o"
+ },
+ "source": [
+ "### 准备数据集\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hNeBO9m-hY7o"
+ },
+ "source": [
+ "使用 Avro 数据集 API 将 `train.avro` 加载为 TensorFlow 数据集:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "v-nbLZHKhY7o"
+ },
+ "outputs": [],
+ "source": [
+ "features = {\n",
+ " 'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32),\n",
+ " 'label': tf.io.FixedLenFeature(shape=[], dtype=tf.int32, default_value=-100),\n",
+ " 'dataType': tf.io.FixedLenFeature(shape=[], dtype=tf.string)\n",
+ "}\n",
+ "\n",
+ "schema = tf.io.gfile.GFile('train.avsc').read()\n",
+ "\n",
+ "dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],\n",
+ " reader_schema=schema,\n",
+ " features=features,\n",
+ " shuffle=False,\n",
+ " batch_size=3,\n",
+ " num_epochs=1)\n",
+ "\n",
+ "for record in dataset:\n",
+ " print(record['features[*]'])\n",
+ " print(record['label'])\n",
+ " print(record['dataType'])\n",
+ " print(\"--------------------\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "IF_kYz_o2DH4"
+ },
+ "source": [
+ "上面的示例将 `train.avro` 转换为 TensorFlow 数据集。数据集的每个元素都是一个字典,其关键字为特征名称,值为转换后的稀疏或密集张量。例如,它会将 `features`、`label`、`dataType` 字段分别转换为 VarLenFeature(SparseTensor)、FixedLenFeature(DenseTensor) 和 FixLenFeature(DenseTensor)。由于 batch_size 为 3,它会将 `train.avro` 中的 3 条记录强制转换为结果数据集中的一个元素。对于 `train.avro` 中标签为 null 的第一条记录,Avro 读取器会将其替换为指定的默认值 (-100)。在本例中,`train.avro` 中总共有 4 条记录。由于批次大小为 3,结果数据集包含 3 个元素,最后一个元素的批次大小为 1。但是,如果大小小于批次大小,用户也可以通过启用 `drop_final_batch` 丢弃最后一个批次。例如:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "bc9vDHyghY7p"
+ },
+ "outputs": [],
+ "source": [
+ "dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],\n",
+ " reader_schema=schema,\n",
+ " features=features,\n",
+ " shuffle=False,\n",
+ " batch_size=3,\n",
+ " drop_final_batch=True,\n",
+ " num_epochs=1)\n",
+ "\n",
+ "for record in dataset:\n",
+ " print(record)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "x45KolnDhY7p"
+ },
+ "source": [
+ "此外,还可以增加 num_parallel_reads 以通过提高 Avro 解析/读取并行性来加速 Avro 数据处理。\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Z2x-gPj_hY7p"
+ },
+ "outputs": [],
+ "source": [
+ "dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],\n",
+ " reader_schema=schema,\n",
+ " features=features,\n",
+ " shuffle=False,\n",
+ " num_parallel_reads=16,\n",
+ " batch_size=3,\n",
+ " drop_final_batch=True,\n",
+ " num_epochs=1)\n",
+ "\n",
+ "for record in dataset:\n",
+ " print(record)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6V-nwDJGhY7p"
+ },
+ "source": [
+ "有关 `make_avro_record_dataset` 的详细用法,请参阅 API 文档。\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vIOijGlAhY7p"
+ },
+ "source": [
+ "### 使用 Avro 数据集训练 tf.keras 模型\n",
+ "\n",
+ "现在,我们来看一个端到端示例,该示例基于 MNIST 数据集使用 Avro 数据集来训练 tf.keras 模型。\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "s7K85D53hY7q"
+ },
+ "source": [
+ "使用 Avro 数据集 API 将 `train.avro` 加载为 TensorFlow 数据集:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VFoeLwIOhY7q"
+ },
+ "outputs": [],
+ "source": [
+ "features = {\n",
+ " 'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32),\n",
+ " 'label': tf.io.FixedLenFeature(shape=[], dtype=tf.int32, default_value=-100),\n",
+ "}\n",
+ "\n",
+ "\n",
+ "schema = tf.io.gfile.GFile('train.avsc').read()\n",
+ "\n",
+ "dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],\n",
+ " reader_schema=schema,\n",
+ " features=features,\n",
+ " shuffle=False,\n",
+ " batch_size=1,\n",
+ " num_epochs=1)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hR2FnIIMhY7q"
+ },
+ "source": [
+ "定义一个简单的 Keras 模型:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "hGV5rHfJhY7q"
+ },
+ "outputs": [],
+ "source": [
+ "def build_and_compile_cnn_model():\n",
+ " model = tf.keras.Sequential()\n",
+ " model.compile(optimizer='sgd', loss='mse')\n",
+ " return model\n",
+ "\n",
+ "model = build_and_compile_cnn_model()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Tuv9n6HshY7q"
+ },
+ "source": [
+ "### 使用 Avro 数据集训练 Keras 模型:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lb44cUuWhY7r"
+ },
+ "outputs": [],
+ "source": [
+ "def extract_label(feature):\n",
+ " label = feature.pop('label')\n",
+ " return tf.sparse.to_dense(feature['features[*]']), label\n",
+ "\n",
+ "model.fit(x=dataset.map(extract_label), epochs=1, steps_per_epoch=1, verbose=1)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7K6qAv5rhY7r"
+ },
+ "source": [
+ "Avro 数据集可以解析任何 Avro 数据并将其强制转换为 TensorFlow 张量,包括记录、映射、数组、分支和枚举中的记录。解析信息作为映射传递到 Avro 数据集实现中,其中关键字用于编码如何解析数据,值用于编码如何将数据强制转换为 TensorFlow 张量 – 决定基元类型(例如 bool、int、long、float、double、string)以及张量类型(例如稀疏或密集)。下面提供了 TensorFlow 解析器类型(见表 1)和基元类型强制转换(表 2)的清单。\n",
+ "\n",
+ "表 1 支持的 TensorFlow 解析器类型:\n",
+ "\n",
+ "TensorFlow 解析器类型 | TensorFlow 张量 | 解释\n",
+ "--- | --- | ---\n",
+ "tf.FixedLenFeature([], tf.int32) | 密集张量 | 解析固定长度的特征;也就是说,所有行都具有相同的恒定数量元素,例如,只有一个元素或每行始终具有相同数量元素的数组\n",
+ "tf.SparseFeature(index_key=['key_1st_index', 'key_2nd_index'], value_key='key_value', dtype=tf.int64, size=[20, 50]) | 稀疏张量 | 解析稀疏特征,其中每行都有一个可变长度的索引和值清单。'index_key' 标识索引。'value_key' 标识值。'dtype' 为数据类型。'size' 为每个索引条目的预期最大索引值\n",
+ "tfio.experimental.columnar.VarLenFeatureWithRank([],tf.int64) | 稀疏张量 | 解析可变长度特征;这意味着每个数据行可以具有可变数量的元素,例如,第一行有 5 个元素,第二行有 7 个元素\n",
+ "\n",
+ "表 2 支持的 Avro 类型到 TensorFlow 类型的转换:\n",
+ "\n",
+ "Avro 基元类型 | TensorFlow 基元类型\n",
+ "--- | ---\n",
+ "bool:二进制值 | tf.bool\n",
+ "byte:8 位无符号字节序列 | tf.string\n",
+ "double:双精度 64 位 IEEE 浮点数 | tf.float64\n",
+ "enum:枚举类型 | 使用符号名称的 tf.string\n",
+ "float:单精度 32 位 IEEE 浮点数 | tf.float32\n",
+ "int:32 位有符号整数 | tf.int32\n",
+ "long:64 位有符号整数 | tf.int64\n",
+ "null:没有值 | 使用默认值\n",
+ "string:unicode 字符序列 | tf.string\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1PFQPuy5hY7r"
+ },
+ "source": [
+ "测试中提供了一组全面的 Avro 数据集 API 示例。\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [
+ "Tce3stUlHN0L"
+ ],
+ "name": "avro.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
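作为对上文表 1 的补充,下面的草图展示了如何在同一个 `features` 映射中混用三种解析器类型;其中 `sparse_idx`、`sparse_val` 等字段名只是假设,必须与您的 Avro 架构中实际声明的字段一致:

```python
import tensorflow as tf
import tensorflow_io as tfio

# A hypothetical features mapping combining the three parser types in Table 1.
# Field names below are illustrative; they must exist in your Avro schema.
features = {
    # Fixed-length field -> dense tensor (null replaced by the default value).
    'label': tf.io.FixedLenFeature(shape=[], dtype=tf.int32, default_value=-100),
    # Variable-length field -> sparse tensor with known rank.
    'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32),
    # Index/value field pair -> sparse tensor (index_key/value_key are assumed).
    'sparse': tf.io.SparseFeature(index_key='sparse_idx',
                                  value_key='sparse_val',
                                  dtype=tf.int64,
                                  size=20),
}
```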
diff --git a/site/zh-cn/io/tutorials/kafka.ipynb b/site/zh-cn/io/tutorials/kafka.ipynb
index a4889041b9..88db9a54e1 100644
--- a/site/zh-cn/io/tutorials/kafka.ipynb
+++ b/site/zh-cn/io/tutorials/kafka.ipynb
@@ -47,13 +47,10 @@
},
"source": [
""
]
},
diff --git a/site/zh-cn/lite/android/lite_build.md b/site/zh-cn/lite/android/lite_build.md
index 2b1929feec..7ebda6f09e 100644
--- a/site/zh-cn/lite/android/lite_build.md
+++ b/site/zh-cn/lite/android/lite_build.md
@@ -30,7 +30,7 @@ allprojects {
-{% dynamic if 'tflite-android-tos' in user.acknowledged_walls and request.tld != 'cn' %} 您可以在此处下载 Docker 文件 {% dynamic else %} 您必须确认服务条款才能下载此文件。确认 {% dynamic endif %}
+{% dynamic if 'tflite-android-tos' in user.acknowledged_walls and request.tld != 'cn' %} 您可以在此处下载 Docker 文件 {% dynamic else %} 您必须确认服务条款才能下载此文件。确认 {% dynamic endif %}
@@ -70,7 +70,7 @@ sdkmanager \
Bazel 是适用于 TensorFlow 的主要构建系统。要使用 Bazel 构建,您必须在系统上安装此工具以及 Android NDK 与 SDK。
1. 安装最新版本的 [Bazel 构建系统](https://bazel.build/versions/master/docs/install.html)。
-2. 需要 Android NDK 才能构建原生 (C/C++) TensorFlow Lite 代码。最新的推荐版本是 17c,在[此处](https://developer.android.com/ndk/downloads/older_releases.html#ndk-19c-downloads)可以找到该版本。
+2. 需要 Android NDK 才能构建原生 (C/C++) TensorFlow Lite 代码。最新的推荐版本是 21e,在[此处](https://developer.android.com/ndk/downloads/older_releases.html#ndk-21e-downloads)可以找到该版本。
3. 在[此处](https://developer.android.com/tools/revisions/build-tools.html)可以获取 Android SDK 和构建工具,或者,您也可以通过 [Android Studio](https://developer.android.com/studio/index.html) 获取。对于 TensorFlow Lite 模型构建,推荐的构建工具 API 版本是 23 或更高版本。
### 配置工作区和 .bazelrc
@@ -85,10 +85,10 @@ Bazel 是适用于 TensorFlow 的主要构建系统。要使用 Bazel 构建,
如果不设置这些变量,则必须在脚本提示中以交互方式提供。如果配置成功,则会在根文件夹的 `.tf_configure.bazelrc` 文件中产生类似以下代码的条目:
```shell
-build --action_env ANDROID_NDK_HOME="/usr/local/android/android-ndk-r19c"
-build --action_env ANDROID_NDK_API_LEVEL="21"
-build --action_env ANDROID_BUILD_TOOLS_VERSION="28.0.3"
-build --action_env ANDROID_SDK_API_LEVEL="23"
+build --action_env ANDROID_NDK_HOME="/usr/local/android/android-ndk-r21e"
+build --action_env ANDROID_NDK_API_LEVEL="26"
+build --action_env ANDROID_BUILD_TOOLS_VERSION="30.0.3"
+build --action_env ANDROID_SDK_API_LEVEL="30"
build --action_env ANDROID_SDK_HOME="/usr/local/android/android-sdk-linux"
```
@@ -99,6 +99,8 @@ build --action_env ANDROID_SDK_HOME="/usr/local/android/android-sdk-linux"
```sh
bazel build -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
+ --define=android_dexmerger_tool=d8_dexmerger \
+ --define=android_incremental_dexing_tool=d8_dexbuilder \
//tensorflow/lite/java:tensorflow-lite
```
diff --git a/site/zh-cn/lite/android/quickstart.md b/site/zh-cn/lite/android/quickstart.md
index b4f786901f..dc9f7daa58 100644
--- a/site/zh-cn/lite/android/quickstart.md
+++ b/site/zh-cn/lite/android/quickstart.md
@@ -23,12 +23,12 @@
1. 克隆 Git 仓库:
git clone https://github.com/tensorflow/examples.git
-
+
2. 将您的 Git 实例配置为使用稀疏签出,这样您就只有目标检测示例应用的文件:
cd examples
- git sparse-checkout init --cone
- git sparse-checkout set lite/examples/object_detection/android_play_services
-
+ git sparse-checkout init --cone
+ git sparse-checkout set lite/examples/object_detection/android_play_services
+
### 导入并运行项目
diff --git a/site/zh-cn/lite/examples/on_device_training/overview.ipynb b/site/zh-cn/lite/examples/on_device_training/overview.ipynb
index 3ec0423740..b54bfb767a 100644
--- a/site/zh-cn/lite/examples/on_device_training/overview.ipynb
+++ b/site/zh-cn/lite/examples/on_device_training/overview.ipynb
@@ -128,7 +128,7 @@
" <figcaption><b>Figure 1</b>: <a href=\"https://github.com/zalandoresearch/fashion-mnist\">Fashion-MNIST samples</a> (by Zalando, MIT License).</figcaption>\n",
"</figure>\n",
"\n",
- "You can explore this dataset in more depth in the [Keras classification tutorial](https://tensorflow.google.cn/tutorials/keras/classification#import_the_fashion_mnist_dataset)."
+ "您可以在 [Keras 分类教程中](https://tensorflow.google.cn/tutorials/keras/classification#import_the_fashion_mnist_dataset)更深入探索此数据集。"
]
},
{
diff --git a/site/zh-cn/lite/examples/pose_estimation/overview.md b/site/zh-cn/lite/examples/pose_estimation/overview.md
index 46aad17abc..b81a036631 100644
--- a/site/zh-cn/lite/examples/pose_estimation/overview.md
+++ b/site/zh-cn/lite/examples/pose_estimation/overview.md
@@ -7,7 +7,7 @@
## 开始使用
-下载此模块
+如果您是 TensorFlow Lite 新用户,并且使用的是 Android 或 iOS,我们建议您研究以下可以帮助您入门的示例应用。
Android 示例 iOS 示例
@@ -21,7 +21,7 @@
### 使用案例
-为了达到清晰的目的,该算法只是对图像中的人简单的预测身体关键位置所在,而不会去辨别此人是谁。
+姿态预测是指检测图像和视频中人物的计算机视觉技术,用于确定某人的身体部位(如肘部)出现在图像中的位置。需要注意的是,姿态预测仅能预估关键身体关节的位置,而无法识别图像或视频中的人物。
姿态预测模型会将处理后的相机图像作为输入,并输出有关关键点的信息。检测到的关键点由部位 ID 索引,置信度分数介于 0.0 和 1.0 之间。置信度分数表示该位置存在关键点的概率。
diff --git a/site/zh-cn/model_optimization/guide/quantization/training_comprehensive_guide.ipynb b/site/zh-cn/model_optimization/guide/quantization/training_comprehensive_guide.ipynb
index 2a9e251f64..d24e09048a 100644
--- a/site/zh-cn/model_optimization/guide/quantization/training_comprehensive_guide.ipynb
+++ b/site/zh-cn/model_optimization/guide/quantization/training_comprehensive_guide.ipynb
@@ -47,10 +47,13 @@
},
"source": [
""
]
},
@@ -62,7 +65,7 @@
"source": [
"欢迎阅读 Keras 量化感知训练的综合指南。\n",
"\n",
- "本页面记录了各种用例,并展示了如何将 API 用于每种用例。了解需要哪些 API 后,可在 [API 文档](https://tensorflow.google.cn/model_optimization/api_docs/python/tfmot/quantization)中找到参数和底层详细信息:\n",
+ "本页面记录了各种用例,并展示了如何将 API 用于每种用例。了解需要哪些 API 后,可在 [API 文档](https://tensorflow.google.cn/model_optimization/api_docs/python/tfmot/quantization)中找到参数和底层详细信息。\n",
"\n",
"- 如果要查看量化感知训练的好处以及支持的功能,请参阅[概述](https://tensorflow.google.cn/model_optimization/guide/quantization/training.md)。\n",
"- 有关单个端到端示例,请参阅[量化感知训练示例](https://tensorflow.google.cn/model_optimization/guide/quantization/training_example.md)。\n",
@@ -105,8 +108,7 @@
},
"outputs": [],
"source": [
- "! pip uninstall -y tensorflow\n",
- "! pip install -q tf-nightly\n",
+ "! pip install -q tensorflow\n",
"! pip install -q tensorflow-model-optimization\n",
"\n",
"import tensorflow as tf\n",
diff --git a/site/zh-cn/probability/examples/Distributed_Inference_with_JAX.ipynb b/site/zh-cn/probability/examples/Distributed_Inference_with_JAX.ipynb
index 3eb579fc80..6f29749668 100644
--- a/site/zh-cn/probability/examples/Distributed_Inference_with_JAX.ipynb
+++ b/site/zh-cn/probability/examples/Distributed_Inference_with_JAX.ipynb
@@ -43,10 +43,8 @@
"\n",
""
]
diff --git a/site/zh-cn/probability/examples/FFJORD_Demo.ipynb b/site/zh-cn/probability/examples/FFJORD_Demo.ipynb
index 2e15273427..1e8e84b204 100644
--- a/site/zh-cn/probability/examples/FFJORD_Demo.ipynb
+++ b/site/zh-cn/probability/examples/FFJORD_Demo.ipynb
@@ -45,8 +45,7 @@
" 在 TensorFlow.org 上查看 | \n",
" 在 Google Colab 中运行 | \n",
" 在 Github 上查看源代码 | \n",
- " 下载笔记本\n",
- " | \n",
+ " 下载笔记本 | \n",
""
]
},
@@ -186,18 +185,23 @@
"\n",
"为了建立这种联系,我们需要进行以下操作:\n",
"\n",
- "1. 在定义**基础分布**的空间 $\\mathcal{Y}$ 与数据域的空间 $\\mathcal{X}$ 之间定义一个双射映射 $\\mathcal{T}*{\\theta}:\\mathbf{x} \\rightarrow \\mathbf{y}$, $\\mathcal{T}*{\\theta}^{1}:\\mathbf{y} \\rightarrow \\mathbf{x}$。\n",
+ "1. 在定义基础分布的空间 $\\mathcal{Y}$ 与数据域的空间 $\\mathcal{X}$ 之间定义一个双射映射 $\\mathcal{T}{em1}{\\theta}:\\mathbf{x} \\rightarrow \\mathbf{y}$, $\\mathcal{T}{/em1}{\\theta}^{1}:\\mathbf{y} \\rightarrow \\mathbf{x}$。\n",
"2. 有效地跟踪我们执行的将概率概念转移到 $\\mathcal{X}$ 上的变形。\n",
"\n",
"在 $\\mathcal{X}$ 上定义的概率分布的以下表达式中对第二个条件进行了形式化:\n",
"\n",
+ "```\n",
"$$ \\log p_{\\mathbf{x}}(\\mathbf{x})=\\log p_{\\mathbf{y}}(\\mathbf{y})-\\log \\operatorname{det}\\left|\\frac{\\partial \\mathcal{T}_{\\theta}(\\mathbf{y})}{\\partial \\mathbf{y}}\\right| $$\n",
+ "```\n",
"\n",
"FFJORD 双射器通过定义以下转换来实现这一点:$$ \\mathcal{T_{\\theta}}: \\mathbf{x} = \\mathbf{z}(t_{0}) \\rightarrow \\mathbf{y} = \\mathbf{z}(t_{1}) \\quad : \\quad \\frac{d \\mathbf{z}}{dt} = \\mathbf{f}(t, \\mathbf{z}, \\theta) $$\n",
"\n",
"只要描述状态 $\\mathbf{z}$ 演化的函数 $\\mathbf{f}$ 表现良好,并且可以通过集成以下表达式来计算 `log_det_jacobian`,则此转换可逆。\n",
"\n",
- "$$ \\log \\operatorname{det}\\left|\\frac{\\partial \\mathcal{T}*{\\theta}(\\mathbf{y})}{\\partial \\mathbf{y}}\\right| = -\\int*{t_{0}}^{t_{1}} \\operatorname{Tr}\\left(\\frac{\\partial \\mathbf{f}(t, \\mathbf{z}, \\theta)}{\\partial \\mathbf{z}(t)}\\right) d t $$\n",
+ "$$\n",
+ "\\log \\operatorname{det}\\left|\\frac{\\partial \\mathcal{T}_{\\theta}(\\mathbf{y})}{\\partial \\mathbf{y}}\\right| = \n",
+ "-\\int_{t_{0}}^{t_{1}} \\operatorname{Tr}\\left(\\frac{\\partial \\mathbf{f}(t, \\mathbf{z}, \\theta)}{\\partial \\mathbf{z}(t)}\\right) d t\n",
+ "$$\n",
"\n",
"在此演示中,我们将训练 FFJORD 双射器,将高斯分布扭曲到 `moons` 数据集定义的分布上。这将分 3 个步骤完成:\n",
"\n",
diff --git a/site/zh-cn/probability/examples/Gaussian_Copula.ipynb b/site/zh-cn/probability/examples/Gaussian_Copula.ipynb
index 4a443717d4..bb12520204 100644
--- a/site/zh-cn/probability/examples/Gaussian_Copula.ipynb
+++ b/site/zh-cn/probability/examples/Gaussian_Copula.ipynb
@@ -42,9 +42,10 @@
"# Copula 入门\n",
"\n",
""
]
@@ -170,7 +171,10 @@
"\n",
"我们从下面的模型开始:\n",
"\n",
- "$$\\begin{align*} X &\\sim \\text{Kumaraswamy}(a, b) \\ Y &\\sim \\text{Gumbel}(\\mu, \\beta) \\end{align*}$$\n",
+ "$$\\begin{align*}\n",
+ "X &\\sim \\text{Kumaraswamy}(a, b) \\\\\n",
+ "Y &\\sim \\text{Gumbel}(\\mu, \\beta)\n",
+ "\\end{align*}$$\n",
"\n",
"使用 Copula 获得一个双变量 R.V. $Z$,其边缘为 [Kumaraswamy](https://en.wikipedia.org/wiki/Kumaraswamy_distribution) 和 [Gumbel](https://en.wikipedia.org/wiki/Gumbel_distribution)。\n",
"\n",
@@ -285,9 +289,9 @@
"id": "HyIufLCQ2PIc"
},
"source": [
- "最后,我们来实际使用这个高斯 Copula。我们将使用 $\\begin{bmatrix}1 & 0\\rho & \\sqrt{(1-\\rho^2)}\\end{bmatrix}$ 的Cholesky,它将对应于方差 1,以及多元正态分布的相关性 $\\rho$。\n",
+ "最后,我们来实际使用这个高斯 Copula。我们将使用 $\\begin{bmatrix}1 & 0\\\\rho & \\sqrt{(1-\\rho^2)}\\end{bmatrix}$ 的Cholesky,它将对应于方差 1,以及多元正态分布的相关性 $\\rho$。\n",
"\n",
- "我们来看几种情况: "
+ "我们来看几个案例: "
]
},
{
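下面的草图(相关性 `rho` 为假设值)用 NumPy 构造上述下三角 Cholesky 因子,并验证它确实对应于方差为 1、相关性为 $\rho$ 的协方差矩阵:

```python
import numpy as np

rho = 0.5  # hypothetical correlation
# Lower-triangular Cholesky factor [[1, 0], [rho, sqrt(1 - rho^2)]]
chol = np.array([[1.0, 0.0],
                 [rho, np.sqrt(1.0 - rho**2)]])
cov = chol @ chol.T
print(cov)  # [[1, rho], [rho, 1]]: unit variances, correlation rho
```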
diff --git a/site/zh-cn/probability/examples/Gaussian_Process_Latent_Variable_Model.ipynb b/site/zh-cn/probability/examples/Gaussian_Process_Latent_Variable_Model.ipynb
index 00b24119a4..3a58faee94 100644
--- a/site/zh-cn/probability/examples/Gaussian_Process_Latent_Variable_Model.ipynb
+++ b/site/zh-cn/probability/examples/Gaussian_Process_Latent_Variable_Model.ipynb
@@ -45,8 +45,7 @@
" 在 TensorFlow.org 上查看 | \n",
" 在 Google Colab 中运行 | \n",
" 在 GitHub 上查看源代码 | \n",
- " 下载笔记本\n",
- " | \n",
+ " 下载笔记本 | \n",
""
]
},
@@ -64,7 +63,7 @@
"\n",
"我们使用所谓的*索引集*来标记 GP 所组成的集合中的每一个随机变量。在有限索引集的情况下,我们只会得到一个多元正态分布。然而,当我们考虑*有限*集合时,GP 最有趣。对于类似 $\\mathbb{R}^D$ 的索引集(其中,我们对*$D$ 维空间中的每个点*都有一个随机变量),可以将 GP 视为随机*函数*上的分布。如果可以实现,从这样的 GP 中进行单次抽样,将为 $\\mathbb{R}^D$ 中的每个点分配一个(联合正态分布的)值。在此 Colab 中,我们将关注一些 $\\mathbb{R}^D$ 上的 GP。\n",
"\n",
- "正态分布完全由其一阶和二阶统计量确定,实际上,定义正态分布的一种方式是使其高阶累积量全部为零。GP 也是如此:我们可以通过描述均值和协方差*来完全指定 GP。回想一下,对于有限维多元正态分布,均值是一个向量,协方差是一个方形的对称正定矩阵。在无限维 GP 中,这些结构会泛化为均值*函数* $m : \\mathbb{R}^D \\to \\mathbb{R}$(在索引集的每个点上定义),以及协方差“*内核*”函数,$k : \\mathbb{R}^D \\times \\mathbb{R}^D \\to \\mathbb{R}$。内核函数必须为[正定](https://en.wikipedia.org/wiki/Positive-definite_function),本质上是说,在有限点集合的限制下,它会产生一个正定矩阵。\n",
+ "正态分布完全由其一阶和二阶统计量确定,实际上,定义正态分布的一种方式是使其高阶累积量全部为零。GP 也是如此:我们可以通过描述均值和协方差**来完全指定 GP。回想一下,对于有限维多元正态分布,均值是一个向量,协方差是一个方形的对称正定矩阵。在无限维 GP 中,这些结构会泛化为均值函数* $m : \\mathbb{R}^D \\to \\mathbb{R}$(在索引集的每个点上定义),以及协方差“*内核*”函数,$k : \\mathbb{R}^D \\times \\mathbb{R}^D \\to \\mathbb{R}$。内核函数必须为[正定](https://en.wikipedia.org/wiki/Positive-definite_function),本质上是说,在有限点集合的限制下,它会产生一个正定矩阵。\n",
"\n",
"GP 的大部分结构都衍生自其协方差内核函数,此函数描述了采样函数的值如何在邻近(或不那么邻近)点之间变化。不同的协方差函数会鼓励不同程度的平滑度。一种常用的内核函数是“指数二次函数”(又称“高斯函数”、“平方指数函数”或“径向基函数”),$k(x, x') = \\sigma^2 e^{(x - x^2) / \\lambda^2}$。其他示例在 David Duvenaud 的 [Kernel Cookbook](https://www.cs.toronto.edu/~duvenaud/cookbook/) 页面和典籍 [Gaussian Processes for Machine Learning](http://www.gaussianprocess.org/gpml/) 中均有概述。"
]
@@ -86,7 +85,7 @@
"source": [
"## 应用 GP:回归和隐变量模型\n",
"\n",
- "使用 GP 的一种方式是回归:给定一组观测数据,其形式为输入 ${x_i}*{i=1}^N$(索引集的元素)和观测值 ${y_i}*{i=1}^N$,我们可以使用这些数据在新的点集 ${x_j^*}_{j=1}^M$ 处形成一个后验预测分布。由于分布都是高斯分布,因此可以归结为一些简单的线性代数(但请注意:必要的计算在数据点数量上具有运行时*立方*,并且在数据点数量上需要空间二次,这是使用 GP 的一个主要限制因素,目前的很多研究都集中在精确后验推断的计算上可行的替代方式上)。我们在 [TFP Colab 中的 GP 回归](https://colab.research.google.com/github/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Gaussian_Process_Regression_In_TFP.ipynb)中更详细地介绍了 GP 回归。\n",
+ "使用 GP 的一种方式是回归:给定一组观测数据,其形式为输入 ${x_i}{em0}{i=1}^N$(索引集的元素)和观测值 ${y_i}{/em0}{i=1}^N$,我们可以使用这些数据在新的点集 ${x_j^*}_{j=1}^M$ 处形成一个后验预测分布。由于分布都是高斯分布,因此可以归结为一些简单的线性代数(但请注意:必要的计算在数据点数量上具有运行时立方,并且在数据点数量上需要空间二次,这是使用 GP 的一个主要限制因素,目前的很多研究都集中在精确后验推断的计算上可行的替代方式上)。我们在 TFP Colab 中的 GP 回归中更详细地介绍了 GP 回归。\n",
"\n",
"使用 GP 的另一种方式是作为隐变量模型:给定高维观测值(例如,图像)的集合,我们可以设想某种低维隐结构。我们假设,在隐结构的条件下,大量的输出(图像中的像素)彼此独立。此模型中的训练包括:\n",
"\n",
diff --git a/site/zh-cn/probability/examples/Multiple_changepoint_detection_and_Bayesian_model_selection.ipynb b/site/zh-cn/probability/examples/Multiple_changepoint_detection_and_Bayesian_model_selection.ipynb
index c69f30c7cb..d9ed6ca525 100644
--- a/site/zh-cn/probability/examples/Multiple_changepoint_detection_and_Bayesian_model_selection.ipynb
+++ b/site/zh-cn/probability/examples/Multiple_changepoint_detection_and_Bayesian_model_selection.ipynb
@@ -172,11 +172,17 @@
"\n",
"我们将此问题建模为切换(非均质)泊松过程:在每个时间点,发生的事件数按泊松分布,事件*比率*取决于未观察到的系统状态 $z_t$:\n",
"\n",
+ "```\n",
"$$x_t \\sim \\text{Poisson}(\\lambda_{z_t})$$\n",
+ "```\n",
"\n",
"隐状态为离散值:$z_t \\in {1, 2, 3, 4}$,因此,$\\lambda = [\\lambda_1, \\lambda_2, \\lambda_3, \\lambda_4]$ 是包含每种状态的泊松比的简单向量。为了对状态随时间的演变进行建模,我们定义一个简单的转移模型 $p(z_t | z_{t-1})$:我们假定在每一步保持前一种状态的概率为 $p$,那么我们随机均匀转移到其他状态的概率为 $1-p$。初始状态也是随机均匀进行选择,因此,我们使用:\n",
"\n",
- "$$ \\begin{align*} z_1 &\\sim \\text{Categorical}\\left(\\left{\\frac{1}{4}, \\frac{1}{4}, \\frac{1}{4}, \\frac{1}{4}\\right}\\right)\\ z_t | z_{t-1} &\\sim \\text{Categorical}\\left(\\left{\\begin{array}{cc}p & \\text{if } z_t = z_{t-1} \\ \\frac{1-p}{4-1} & \\text{otherwise}\\end{array}\\right}\\right) \\end{align*}$$\n",
+ "$$\n",
+ "\\begin{align*}\n",
+ "z_1 &\\sim \\text{Categorical}\\left(\\left\\{\\frac{1}{4}, \\frac{1}{4}, \\frac{1}{4}, \\frac{1}{4}\\right\\}\\right)\\\\\n",
+ "z_t | z_{t-1} &\\sim \\text{Categorical}\\left(\\left\\{\\begin{array}{cc}p & \\text{if } z_t = z_{t-1} \\\\ \\frac{1-p}{4-1} & \\text{otherwise}\\end{array}\\right\\}\\right)\n",
+ "\\end{align*}$$\n",
"\n",
"上述假设对应于具有泊松发射的[隐马尔可夫模型](http://mlg.eng.cam.ac.uk/zoubin/papers/ijprai.pdf)。我们可以使用 `tfd.HiddenMarkovModel` 在 TFP 中对这些假设进行编码。首先,我们在初始状态中定义转移矩阵和均匀先验:"
]
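下面是一个最小草图(泊松比率 $\lambda$ 与留存概率 $p$ 为假设值),展示如何用 `tfd.HiddenMarkovModel` 编码上述切换泊松模型:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

num_states = 4
p_stay = 0.9                               # assumed self-transition probability
rates = tf.constant([5., 20., 45., 100.])  # hypothetical per-state Poisson rates

initial_probs = tf.fill([num_states], 1. / num_states)
# p on the diagonal, (1 - p) / (num_states - 1) everywhere else.
transition_probs = (
    p_stay * tf.eye(num_states) +
    (1. - p_stay) / (num_states - 1) *
    (tf.ones([num_states, num_states]) - tf.eye(num_states)))

hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=initial_probs),
    transition_distribution=tfd.Categorical(probs=transition_probs),
    observation_distribution=tfd.Poisson(rate=rates),
    num_steps=100)

observed_counts = hmm.sample()             # draw a synthetic count series
print(hmm.log_prob(observed_counts))
```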
@@ -508,7 +514,9 @@
"\n",
"遗憾的是,真实边缘似然对离散状态 $z_{1:T}$ 和比率参数(的向量) $\\lambda$, $$p(x_{1:T}) = \\int p(x_{1:T}, z_{1:T}, \\lambda) dz d\\lambda,$$ 同时积分,对于此模型不易处理。为方便起见,我们将使用所谓的“[经验贝叶斯](https://www.cs.ubc.ca/~schmidtm/Courses/540-W16/L19.pdf)”或“第二类极大似然”估计来逼近它:我们将优化与每种系统状态关联的(未知)比率参数 $\\lambda$ 的值,而不是完全积分这些参数:\n",
"\n",
+ "```\n",
"$$\\tilde{p}(x_{1:T}) = \\max_\\lambda \\int p(x_{1:T}, z_{1:T}, \\lambda) dz$$\n",
+ "```\n",
"\n",
"这种逼近可能会过度拟合,即,它将偏向于比真实边缘似然所需模型更复杂的模型。我们可以考虑较为准确可靠的逼近,例如,优化变化的下限,或者使用蒙特卡洛估计器,例如,[退火重要性采样](https://tensorflow.google.cn/probability/api_docs/python/tfp/mcmc/sample_annealed_importance_chain);(可惜)这些内容超出了本笔记本讨论的范畴。(有关贝叶斯模型选择和逼近的详细信息,请参阅[机器学习:概率视角](https://www.cs.ubc.ca/~murphyk/MLbook/)的第 7 章,该部分内容非常精彩,可以为您提供不错的参考。)\n",
"\n",
diff --git a/site/zh-cn/probability/examples/Probabilistic_Layers_VAE.ipynb b/site/zh-cn/probability/examples/Probabilistic_Layers_VAE.ipynb
index 7686b5ff80..fa56f67e5b 100644
--- a/site/zh-cn/probability/examples/Probabilistic_Layers_VAE.ipynb
+++ b/site/zh-cn/probability/examples/Probabilistic_Layers_VAE.ipynb
@@ -452,7 +452,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Originals:\n"
+ "Decoded Modes:\n"
]
},
{
@@ -472,7 +472,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Decoded Modes:\n"
+ "Decoded Means:\n"
]
},
{
@@ -588,7 +588,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Randomly Generated Means:\n"
+ "Randomly Generated Modes:\n"
]
},
{
diff --git a/site/zh-cn/probability/examples/TFP_Release_Notebook_0_11_0.ipynb b/site/zh-cn/probability/examples/TFP_Release_Notebook_0_11_0.ipynb
index b44fe7bb83..8aff319958 100644
--- a/site/zh-cn/probability/examples/TFP_Release_Notebook_0_11_0.ipynb
+++ b/site/zh-cn/probability/examples/TFP_Release_Notebook_0_11_0.ipynb
@@ -109,16 +109,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:125: UserWarning: No GPU/TPU found, falling back to CPU.\n",
- " warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
+ "jit vmap sample: [ 2.17746 2.6618252 3.427014 -0.80979496 5.87146 4.2002716\n",
+ " 1.2994273 1.2281269 3.5244293 4.1996603 ]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "jit vmap sample: [ 2.17746 2.6618252 3.427014 -0.80979496 5.87146 4.2002716\n",
- " 1.2994273 1.2281269 3.5244293 4.1996603 ]\n"
+ "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:125: UserWarning: No GPU/TPU found, falling back to CPU.\n",
+ " warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
]
},
{
@@ -1310,10 +1310,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "(, ) \n",
- " (scalar sample)\n",
"WARNING:tensorflow:Note that RandomUniformInt inside pfor op may not give same output as inside a sequential loop.\n"
]
},
@@ -1321,7 +1317,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "WARNING:tensorflow:Note that RandomUniformInt inside pfor op may not give same output as inside a sequential loop.\n"
+ "WARNING:tensorflow:Note that RandomStandardNormal inside pfor op may not give same output as inside a sequential loop.\n"
]
},
{
@@ -1335,17 +1331,17 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "WARNING:tensorflow:Note that RandomStandardNormal inside pfor op may not give same output as inside a sequential loop.\n"
+ "(, ) \n",
+ " (scalar sample)\n",
+ "WARNING:tensorflow:Note that RandomUniformInt inside pfor op may not give same output as inside a sequential loop.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "(, ) \n",
- " ((1,) sample)\n",
"WARNING:tensorflow:Note that RandomUniformInt inside pfor op may not give same output as inside a sequential loop.\n"
]
},
@@ -1353,19 +1349,19 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "WARNING:tensorflow:Note that RandomUniformInt inside pfor op may not give same output as inside a sequential loop.\n"
+ "(, ) \n",
+ " ((3,) sample)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
- "text": [
- "WARNING:tensorflow:Note that RandomStandardNormal inside pfor op may not give same output as inside a sequential loop.\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
"text": [
"(, , ) \n",
+ " ((1,) sample)\n",
+ "WARNING:tensorflow:Note that RandomUniformInt inside pfor op may not give same output as inside a sequential loop.\n"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
diff --git a/site/zh-cn/quantum/tutorials/mnist.ipynb b/site/zh-cn/quantum/tutorials/mnist.ipynb
index c7e3167698..20ecfbda50 100644
--- a/site/zh-cn/quantum/tutorials/mnist.ipynb
+++ b/site/zh-cn/quantum/tutorials/mnist.ipynb
@@ -977,8 +977,8 @@
"cnn_accuracy = cnn_results[1]\n",
"fair_nn_accuracy = fair_nn_results[1]\n",
"\n",
- "sns.barplot([\"Quantum\", \"Classical, full\", \"Classical, fair\"],\n",
- " [qnn_accuracy, cnn_accuracy, fair_nn_accuracy])"
+ "sns.barplot(x=[\"Quantum\", \"Classical, full\", \"Classical, fair\"],\n",
+ " y=[qnn_accuracy, cnn_accuracy, fair_nn_accuracy])"
]
}
],
diff --git a/site/zh-cn/quantum/tutorials/noise.ipynb b/site/zh-cn/quantum/tutorials/noise.ipynb
index af49c630bc..8b9389d29c 100644
--- a/site/zh-cn/quantum/tutorials/noise.ipynb
+++ b/site/zh-cn/quantum/tutorials/noise.ipynb
@@ -47,8 +47,7 @@
},
"source": [
"