Contents
  1. TensorFlow 2.x Architecture Overview
  2. Tensor Basics and Eager Execution
    2.1 Creating Tensors
    2.2 Tensor Attributes and Types
    2.3 Device Management
  3. Automatic Differentiation: tf.GradientTape
    3.1 Basic Usage
    3.2 Key Parameters
    3.3 Higher-Order Gradients
    3.4 Controlling Gradients
  4. tf.function and AutoGraph
    4.1 Basic Usage
    4.2 How Tracing Works
    4.3 The input_signature Parameter
    4.4 AutoGraph: Converting Python Code to TF Graphs
    4.5 tf.py_function: An Escape Hatch for Python Code
    4.6 tf.Variable Behavior Inside tf.function
  5. tf.data Input Pipelines
    5.1 Creating a Dataset
    5.2 Transformations
    5.3 Notes on Custom map Functions
    5.4 The TFRecord Format
  6. Keras Integration
    6.1 Three Ways to Build Models
    6.2 Custom Training Loops
    6.3 add_loss and add_metric
    6.4 The Callbacks System
  7. Distributed Training Strategies
    7.1 MirroredStrategy (Single Machine, Multiple GPUs)
    7.2 MultiWorkerMirroredStrategy (Multiple Machines, Multiple GPUs)
    7.3 TPUStrategy
    7.4 Custom Distributed Training Loops
    7.5 ParameterServerStrategy
  8. Model Deployment with TF Serving
    8.1 Exporting a SavedModel
    8.2 Signatures in Detail
    8.3 Deploying with TF Serving
    8.4 Batching Configuration
  9. On-Device Deployment with TF Lite
    9.1 Model Conversion
    9.2 Quantization Strategies
    9.3 Delegate Acceleration
    9.4 Metadata and Model Card
    9.5 TF Lite Inference
  10. End-to-End CIFAR-10 Training Pipeline
  11. Debugging and Performance Optimization
    11.1 Common Debugging Tools
    11.2 Performance Tips
  12. The TF Ecosystem
    12.1 TensorFlow Extended (TFX)
    12.2 TensorBoard Visualization
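Machine Learning Frameworks: TensorFlow

TensorFlow is an open-source, end-to-end machine learning platform developed by Google. This article focuses on the TensorFlow 2.x architecture and builds a complete picture of the framework: Eager Execution, automatic graph building with tf.function, tf.data input pipelines and the high-level Keras API, through to distribution strategies, model deployment with TF Serving, and on-device inference with TF Lite.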

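1. TensorFlow 2.x Architecture Overview

TensorFlow 2.x redesigned the framework fundamentally compared with 1.x: Eager Execution (define-by-run) is enabled by default, Keras is the unified high-level API, and cumbersome 1.x concepts such as tf.Session, tf.placeholder and tf.global_variables_initializer() were dropped from the main API (they remain available only under tf.compat.v1).

import tensorflow as tf

# Core design principles of TF 2.x:
# 1. Eager Execution by default: imperative code that runs immediately
# 2. Keras as the unified high-level API
# 3. tf.function compiles Python code into high-performance graphs
# 4. tf.data as the standard input pipeline

print(tf.__version__)          # e.g. 2.15.0
print(tf.executing_eagerly())  # True (the default)

TF 1.x vs 2.x:

| Feature | TF 1.x | TF 2.x |
|---|---|---|
| Execution mode | Static graph (build first, then run) | Eager (runs immediately) |
| Session | tf.Session().run() | Call functions directly |
| Placeholders | tf.placeholder | Not needed (function arguments instead) |
| Variable initialization | tf.global_variables_initializer() | Initialized on creation |
| High-level API | tf.layers / tf.estimator | tf.keras |
| Graph construction | Manual | Automatic via tf.function |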

2. Tensor Basics and Eager Execution

2.1 Creating Tensors

import tensorflow as tf
import numpy as np

# Constants (immutable)
a = tf.constant([1, 2, 3])            # int32
b = tf.constant([1.0, 2.0, 3.0])      # float32
c = tf.constant([[1, 2], [3, 4]])     # 2-D
d = tf.constant(np.array([1, 2, 3]))  # from NumPy; keeps NumPy's dtype (int64 on most platforms)

# Special values
zeros = tf.zeros([2, 3])
ones = tf.ones([3, 4])
eye = tf.eye(3)
fill = tf.fill([2, 3], 9.0)

# Sequences
r = tf.range(0, 10, 2)        # [0, 2, 4, 6, 8]
l = tf.linspace(0.0, 1.0, 5)  # [0.0, 0.25, 0.5, 0.75, 1.0]

# Random values (set the global seed first so the ops below are reproducible)
tf.random.set_seed(42)
rand = tf.random.uniform([2, 3], minval=0, maxval=1)
randn = tf.random.normal([2, 3], mean=0, stddev=1)
randint = tf.random.uniform([3, 3], minval=0, maxval=10, dtype=tf.int32)
trunc_norm = tf.random.truncated_normal([2, 3], mean=0, stddev=1)
# truncated_normal: samples more than 2 standard deviations from the mean are dropped
# and redrawn (avoids extreme values)

Variables (tf.Variable): mutable tensors used to store model weights:

# Create variables
v = tf.Variable(tf.random.normal([3, 3]))
w = tf.Variable(tf.zeros([10, 5]), name='weights')
b = tf.Variable(tf.zeros([5]), name='bias')

# Attributes
print(v.shape)      # (3, 3)
print(v.dtype)      # <dtype: 'float32'>
print(v.device)     # /job:localhost/replica:0/task:0/device:CPU:0 (GPU:0 if a GPU is available)
print(v.trainable)  # True (takes part in gradient updates by default)
print(v.name)       # Variable:0

# Update in place
v.assign(tf.ones([3, 3]))      # assign
v.assign_add(tf.ones([3, 3]))  # add
v.assign_sub(tf.ones([3, 3]))  # subtract

# trainable controls whether tapes and optimizers track the variable
v2 = tf.Variable(tf.ones([3]), trainable=False)
# If v2 belongs to a layer/model, it is listed in non_trainable_variables, not in trainable_variables

2.2 Tensor Attributes and Types

# dtype family
# tf.float32 / tf.float64 / tf.float16 / tf.bfloat16
# tf.int32 / tf.int64 / tf.int16 / tf.int8 / tf.uint8
# tf.bool / tf.string / tf.complex64 / tf.complex128

# Type casting
x = tf.constant([1, 2, 3], dtype=tf.int32)
x_float = tf.cast(x, tf.float32)

# Shape operations
x = tf.ones([2, 3, 4])
tf.shape(x)                       # <tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 3, 4])>
tf.rank(x)                        # 3
tf.size(x)                        # 24 = 2*3*4
tf.reshape(x, [6, 4])             # reshape to (6, 4)
tf.expand_dims(x, axis=0)         # (1, 2, 3, 4)
tf.squeeze(tf.expand_dims(x, 0))  # drop size-1 dimensions → (2, 3, 4)
tf.transpose(x, [2, 0, 1])        # permute dimensions → (4, 2, 3)

# Concatenate and split
a = tf.ones([2, 3])
b = tf.zeros([2, 3])
c = tf.concat([a, b], axis=0)                   # (4, 3)
s = tf.stack([a, b], axis=0)                    # (2, 2, 3)
sp = tf.split(x, num_or_size_splits=2, axis=0)  # 2 tensors of shape (1, 3, 4)
unstack = tf.unstack(x, axis=0)                 # list of 2 tensors of shape (3, 4)

# Indexing and slicing (on a 3x4 matrix)
m = tf.reshape(tf.range(12), [3, 4])
m[0]                               # row 0
m[:, 0:2]                          # all rows, first two columns
tf.gather(m, [0, 2])               # rows 0 and 2
tf.gather_nd(m, [[0, 1], [1, 2]])  # elements (0, 1) and (1, 2) → [1, 6]

# Broadcasting
a = tf.ones([3, 4])
b = tf.ones([4])            # shape (4,) is broadcast to (3, 4)
c = a + b                   # (3, 4)
tf.broadcast_to(b, [3, 4])  # explicit broadcast
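The broadcasting rule used above (align shapes on their trailing dimensions; each pair of sizes must be equal or one of them must be 1) can be sketched in plain Python. This is a minimal illustration with a hypothetical helper, `broadcast_shape`, which is not a TensorFlow API; it follows the NumPy-style rule that TensorFlow also applies.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Return the broadcast result of several shapes, NumPy/TF style.

    Shapes are aligned on their trailing dimensions; a missing leading
    dimension behaves like size 1. Raises ValueError if incompatible.
    """
    result = []
    # Walk the dimensions from last to first
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible dimensions: {dims}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


print(broadcast_shape((3, 4), (4,)))       # (3, 4)
print(broadcast_shape((2, 1, 4), (3, 1)))  # (2, 3, 4)
```

The second call shows that both operands can be stretched at once: (2, 1, 4) grows along its middle axis while (3, 1) gains a leading axis and grows along its last one.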

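2.3 Device Management

# Check for GPUs
gpus = tf.config.list_physical_devices('GPU')
print(gpus)
print(len(gpus) > 0)  # replaces the deprecated tf.test.is_gpu_available()
if gpus:
    print(tf.config.experimental.get_device_details(gpus[0]))

# Pin ops to a device
with tf.device('/CPU:0'):
    x = tf.constant([1.0, 2.0])

with tf.device('/GPU:0'):
    y = tf.constant([1.0, 2.0])

# Falling back to an available device (e.g. the CPU) when the requested device does not exist
# relies on soft device placement (tf.config.set_soft_device_placement); if it is disabled, an error is raised

# GPU memory configuration: choose ONE option below and run it before any op touches the GPU
# (otherwise TF raises a RuntimeError). Memory growth and a memory limit cannot both be set on one GPU.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # Option A (recommended): allocate GPU memory on demand
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    # Option B: cap GPU 0 at 4 GB (older alias: tf.config.experimental.set_virtual_device_configuration)
    # tf.config.set_logical_device_configuration(
    #     gpus[0],
    #     [tf.config.LogicalDeviceConfiguration(memory_limit=4096)])

    # Option C: split one physical GPU into two logical GPUs (GPU:0, GPU:1) of 2 GB each
    # tf.config.set_logical_device_configuration(
    #     gpus[0],
    #     [tf.config.LogicalDeviceConfiguration(memory_limit=2048),
    #      tf.config.LogicalDeviceConfiguration(memory_limit=2048)])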

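3. Automatic Differentiation: tf.GradientTape

TensorFlow 2.x uses tape-based (reverse-mode) automatic differentiation: tf.GradientTape records the operations executed during the forward pass and then replays them backwards to compute gradients.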
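To make the record-then-replay idea concrete, here is a toy reverse-mode autodiff in plain Python. It is a teaching sketch only: the `Tape` and `Var` classes are hypothetical and support just `+` and `*`, whereas TensorFlow's real tape records TF ops and uses registered gradient functions. The control flow is the same, though: record each op with its local derivatives during the forward pass, then walk the records backwards, accumulating chain-rule products.

```python
class Tape:
    """Records (output, [(input, local_grad), ...]) for every op, in execution order."""

    def __init__(self):
        self.records = []

    def gradient(self, target, source):
        grads = {id(target): 1.0}  # d target / d target = 1
        # Replay the recorded ops in reverse order (reverse-mode autodiff)
        for out, parents in reversed(self.records):
            g_out = grads.get(id(out), 0.0)
            for parent, local in parents:
                grads[id(parent)] = grads.get(id(parent), 0.0) + g_out * local
        return grads.get(id(source), 0.0)


class Var:
    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def _record(self, value, parents):
        out = Var(value, self.tape)
        self.tape.records.append((out, parents))
        return out

    def __add__(self, other):
        # d(a+b)/da = 1, d(a+b)/db = 1
        return self._record(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        # d(a*b)/da = b, d(a*b)/db = a
        return self._record(self.value * other.value,
                            [(self, other.value), (other, self.value)])


tape = Tape()
x = Var(2.0, tape)
y = x * x * x               # forward pass: two multiplications are recorded
print(tape.gradient(y, x))  # 12.0 = 3 * x**2 at x = 2
```

Because the gradient of every recorded output is accumulated before its inputs are visited, a value used several times (here `x`) correctly receives the sum of all its contributions.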

3.1 Basic Usage

# Simple gradient
x = tf.Variable(3.0)
with tf.GradientTape() as tape:
    y = x ** 2
dy_dx = tape.gradient(y, x)
print(dy_dx.numpy())  # 6.0 (= 2*3)

# Gradients with respect to several variables
w = tf.Variable(tf.random.normal([3, 2]))
b = tf.Variable(tf.zeros([2]))
with tf.GradientTape() as tape:
    y = tf.matmul(tf.constant([[1.0, 2.0, 3.0]]), w) + b
    loss = tf.reduce_mean(y)
grads = tape.gradient(loss, [w, b])
print(grads[0].shape)  # (3, 2)
print(grads[1].shape)  # (2,)

3.2 Key Parameters

# persistent=True: allows calling tape.gradient more than once
x = tf.Variable(2.0)
with tf.GradientTape(persistent=True) as tape:
    y = x ** 3
    z = x ** 2

dy_dx = tape.gradient(y, x)  # 3*x^2 at x=2 → 12.0
dz_dx = tape.gradient(z, x)  # 2*x at x=2 → 4.0
del tape  # release the tape's resources explicitly when done

# With the default persistent=False, gradient() can be called only once

# watch_accessed_variables=True (default): automatically watches trainable variables read inside the tape
# watch_accessed_variables=False: only tensors passed to tape.watch() are tracked
x = tf.constant(2.0)  # constants are never watched automatically
with tf.GradientTape() as tape:
    tape.watch(x)  # watch it manually
    y = x ** 3
dy_dx = tape.gradient(y, x)  # 3*x^2 at x=2 → 12.0

3.3 Higher-Order Gradients

x = tf.Variable(3.0)

# Second-order gradient: nest the tapes
with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as inner_tape:
        y = x ** 4
    dy_dx = inner_tape.gradient(y, x)    # first order: 4*x^3 = 108
d2y_dx2 = outer_tape.gradient(dy_dx, x)  # second order: 12*x^2 = 108

# Higher-order gradients of a constant using watch
x = tf.constant(3.0)
with tf.GradientTape() as t2:
    t2.watch(x)
    with tf.GradientTape() as t1:
        t1.watch(x)
        y = x ** 4
    dy_dx = t1.gradient(y, x)
d2y_dx2 = t2.gradient(dy_dx, x)
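A quick, framework-independent way to sanity-check gradient values such as the ones in 3.2 and 3.3 is a central finite difference. This plain-Python sketch uses hypothetical helper names (`central_diff`, `second_diff`); the step sizes `h` are assumptions that trade truncation error against floating-point round-off.

```python
def central_diff(f, x, h=1e-5):
    """First derivative: (f(x+h) - f(x-h)) / 2h, error O(h^2)."""
    return (f(x + h) - f(x - h)) / (2 * h)


def second_diff(f, x, h=1e-3):
    """Second derivative: (f(x+h) - 2f(x) + f(x-h)) / h^2, error O(h^2)."""
    return (f(x + h) - 2 * f(x) + f(x - h)) / (h * h)


print(round(central_diff(lambda x: x ** 3, 2.0), 4))  # 12.0  (3x^2 at x=2)
print(round(central_diff(lambda x: x ** 4, 3.0), 4))  # 108.0 (4x^3 at x=3)
print(round(second_diff(lambda x: x ** 4, 3.0), 2))   # 108.0 (12x^2 at x=3)
```

In a real model the same idea is applied by perturbing one weight at a time and comparing against `tape.gradient`; it is slow, so it is only used as a debugging check.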

3.4 Controlling Gradients

# stop_gradient: block gradient flow
x = tf.Variable(2.0)
with tf.GradientTape() as tape:
    y = x ** 2
    z = tf.stop_gradient(y) * x  # y is treated as a constant
dz_dx = tape.gradient(z, x)  # = y = 4.0 (the ∂y/∂x term is dropped)

# tape.jacobian: Jacobian matrix
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.reduce_sum(x ** 2)  # scalar output
jac = tape.jacobian(y, x)
print(jac)  # [[2., 4.], [6., 8.]] (for a scalar y this is the gradient 2x, with the shape of x)

# tape.batch_jacobian: per-example Jacobians
batch_x = tf.constant([[[1.0, 2.0]], [[3.0, 4.0]]])  # (2, 1, 2)
with tf.GradientTape() as tape:
    tape.watch(batch_x)
    y = batch_x ** 2  # (2, 1, 2)
batch_jac = tape.batch_jacobian(y, batch_x)
# Result shape: (2, 1, 2, 1, 2) = [batch] + y.shape[1:] + x.shape[1:], computed independently per example

4. tf.function and AutoGraph

tf.function is one of the core features of TF 2.x: it compiles a Python function into a high-performance TensorFlow graph, combining the convenience of eager code with the speed of graph execution.

4.1 Basic Usage

# As a decorator
@tf.function
def simple_op(x, y):
    return tf.matmul(x, y) + 1

# By wrapping a function directly
def my_op(x, y):
    return tf.matmul(x, y) + 1
compiled_op = tf.function(my_op)

# Calling it
x = tf.ones([3, 4])
y = tf.ones([4, 5])
result = simple_op(x, y)  # the first call traces; later calls with matching inputs run the cached graph
print(result.shape)  # (3, 5)

# Inspect the generated graph
print(simple_op.get_concrete_function(x, y).graph.as_graph_def())
# Or use TensorBoard:
# writer = tf.summary.create_file_writer('./logs')
# tf.summary.trace_on(graph=True)
# simple_op(x, y)  # call the function while tracing is on
# with writer.as_default():
#     tf.summary.trace_export(name="my_func", step=0)

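4.2 How Tracing Works

# How tf.function works:
# 1. First call → Python tracing → a ConcreteFunction (graph) is built → cached
# 2. Later calls → cache lookup by input signature → the graph runs directly

@tf.function
def add(a, b):
    print("Tracing...")  # Python side effect: runs only while tracing
    return a + b

# First call
add(tf.constant(1), tf.constant(2))          # prints "Tracing..."
# Second call (same dtypes and shapes)
add(tf.constant(3), tf.constant(4))          # prints nothing (cache hit)
# A different dtype (and shape) triggers a retrace
add(tf.constant([1.0]), tf.constant([2.0]))  # prints "Tracing..." (float32 of shape (1,) vs int32 scalar)

# What triggers a retrace:
# 1. A Tensor argument's dtype changes
# 2. A Tensor argument's shape changes (not only its rank), unless
#    reduce_retracing=True lets TF generalize to a more flexible shape
# 3. A Python (non-Tensor) argument takes a new value (wrap it in tf.constant to avoid this)
#
# What does NOT trigger a retrace:
# - Reassigning a Python global that the function reads: the value captured at trace time
#   is silently reused, so results can go stale
# - Changing the value of a captured tf.Variable: it is read at run time, so the graph sees the new value

# Avoiding unnecessary retracing:
# - Pass an input_signature
# - Use tf.TensorSpec (with None for variable dimensions) to fix the input format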
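The cache lookup described above can be modeled in a few lines of plain Python. This is a simplified sketch, not TensorFlow's implementation: `FakeTensor` and `traced` are hypothetical stand-ins, tensors are keyed by (dtype, shape) only, and any other Python argument is keyed by its value, which is why passing changing Python scalars causes a retrace on every new value.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Stand-in for a tf.Tensor: only dtype and shape matter for tracing."""
    dtype: str
    shape: tuple
    value: object = None


def traced(fn):
    cache = {}  # signature key -> "concrete function"

    def key_of(arg):
        if isinstance(arg, FakeTensor):
            return ("tensor", arg.dtype, arg.shape)  # the tensor's value is ignored
        return ("python", arg)                       # Python values are part of the key

    def wrapper(*args):
        key = tuple(key_of(a) for a in args)
        if key not in cache:
            wrapper.trace_count += 1                 # a new trace
            cache[key] = fn                          # stands in for the traced graph
        return cache[key](*args)

    wrapper.trace_count = 0
    return wrapper


@traced
def add(a, b):
    return (a, b)


add(FakeTensor("int32", (), 1), FakeTensor("int32", (), 2))                  # trace 1
add(FakeTensor("int32", (), 3), FakeTensor("int32", (), 4))                  # cache hit
add(FakeTensor("float32", (1,), [1.0]), FakeTensor("float32", (1,), [2.0]))  # trace 2
add(1, 2)                                                                    # trace 3
add(3, 4)                                                                    # trace 4: new Python values
print(add.trace_count)  # 4
```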

4.3 The input_signature Parameter

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def process(x):
    return tf.reduce_sum(x, axis=1)

# The signature is fixed at decoration time, so every batch size reuses one graph
process(tf.constant([[1.0, 2.0, 3.0]]))                   # batch=1
process(tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))  # batch=2, no retrace
# Inputs that do not match the signature (e.g. shape [None, 4] or int32) raise an error instead of retracing

# Get the function's ConcreteFunction
concrete_fn = process.get_concrete_function(
    tf.TensorSpec(shape=[None, 3], dtype=tf.float32)
)

# Multiple signatures
@tf.function
def flexible_fn(x):
    return tf.reduce_sum(x)

# Create several ConcreteFunctions by hand
concrete_int = flexible_fn.get_concrete_function(
    tf.TensorSpec(shape=[None], dtype=tf.int32)
)
concrete_float = flexible_fn.get_concrete_function(
    tf.TensorSpec(shape=[None], dtype=tf.float32)
)
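With an input_signature, a call is accepted when its dtype matches and every dimension matches the spec, where None matches any size. A minimal plain-Python sketch of that compatibility rule (the helper `is_compatible` is hypothetical; in TensorFlow the corresponding check is `tf.TensorSpec.is_compatible_with`). Unknown-rank specs (shape=None) are not modeled here.

```python
def is_compatible(spec_shape, spec_dtype, shape, dtype):
    """True if a tensor (shape, dtype) fits a spec; None in the spec matches any size."""
    if dtype != spec_dtype or len(shape) != len(spec_shape):
        return False
    return all(s is None or s == d for s, d in zip(spec_shape, shape))


spec = ([None, 3], "float32")
print(is_compatible(*spec, (1, 3), "float32"))  # True  (batch=1)
print(is_compatible(*spec, (2, 3), "float32"))  # True  (batch=2, same graph)
print(is_compatible(*spec, (2, 4), "float32"))  # False (wrong inner dimension)
print(is_compatible(*spec, (2, 3), "int32"))    # False (wrong dtype)
```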

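4.4 AutoGraph: Converting Python Code to TF Graphs

# AutoGraph automatically converts Python control flow into tf.cond / tf.while_loop

@tf.function
def fizzbuzz(n):
    result = tf.TensorArray(tf.string, size=n, dynamic_size=False)
    for i in tf.range(n):  # Python for over a Tensor → tf.while_loop
        if i % 15 == 0:    # Python if on a Tensor condition → tf.cond
            result = result.write(i, 'FizzBuzz')
        elif i % 3 == 0:
            result = result.write(i, 'Fizz')
        elif i % 5 == 0:
            result = result.write(i, 'Buzz')
        else:
            result = result.write(i, tf.as_string(i))
    return result.stack()

print(fizzbuzz(tf.constant(16)))

# Show the code AutoGraph generates
print(tf.autograph.to_code(fizzbuzz.python_function))

# Control-flow conversions AutoGraph supports:
# if/elif/else    → tf.cond (when the condition is a Tensor)
# for x in y      → tf.while_loop (when y is a Tensor)
# while           → tf.while_loop (when the condition is a Tensor)
# break/continue  → equivalent conditional control
# return          → conditional return
# Conditions and loops over plain Python values are not converted: they run once, at trace time

# Python features AutoGraph does not convert:
# - generators / yield
# - try/except/finally: left as plain Python, so it runs only at trace time
# - mutating Python lists/dicts inside a Tensor-dependent loop (use tf.TensorArray instead)
# - dynamically adding object attributes inside converted control flow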

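4.5 tf.py_function: An Escape Hatch for Python Code

# tf.py_function lets a graph call arbitrary Python code
# Cost: slower (holds the Python GIL, no graph optimization) and not portable: the Python body
# is not serialized into the SavedModel, so TF Serving / TF Lite cannot run it

def numpy_custom_op(x):
    """A complex operation implemented with NumPy (x arrives as an EagerTensor)"""
    result = np.fft.fft(x.numpy())
    return result.astype(np.complex64)

@tf.function
def model_forward(x):
    # Call the NumPy code from inside the graph
    y = tf.py_function(func=numpy_custom_op, inp=[x], Tout=tf.complex64)
    y.set_shape(x.shape)  # set the static shape by hand (py_function cannot infer it)
    return y

# tf.numpy_function: similar, but passes NumPy arrays directly and is not differentiable
# Use it for non-differentiable pre/post-processing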

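4.6 tf.Variable Behavior Inside tf.function

# Rule: a tf.Variable created outside the tf.function can be read and written directly.
# Variables may be created inside a tf.function only on its first call (the first trace);
# creating new variables on a later call raises a ValueError.

counter = tf.Variable(0)

@tf.function
def increment():
    counter.assign_add(1)  # modify the outer variable
    return counter.value()

print(increment())  # 1 (printed as a tf.Tensor)
print(increment())  # 2
print(increment())  # 3

# Creating a variable lazily inside a tf.function (use with care):
# keep it on an object and create it only once
class LazyDense(tf.Module):
    def __init__(self, units=10):
        super().__init__()
        self.units = units
        self.w = None

    @tf.function
    def __call__(self, x):
        if self.w is None:  # a Python check, evaluated only during the first trace
            self.w = tf.Variable(tf.random.normal([x.shape[-1], self.units]))
        return x @ self.w

# Variable capture:
# - tf.function tracks which outer variables the function accesses
# - every outer Variable it accesses becomes an implicit input of the graph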

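5. tf.data Input Pipelines

tf.data is the standard API for building high-performance input pipelines; it supports lazy loading, parallel processing, prefetching and other optimizations.

5.1 Creating a Dataset

import numpy as np
import tensorflow as tf

# From Tensors / NumPy arrays
features = np.random.rand(100, 8).astype(np.float32)  # hypothetical data: 100 samples, 8 features
labels = np.random.randint(0, 2, size=(100,))
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# Slices along the first dimension N: each element is (features[i], labels[i]) with shapes ((8,), ())

# From a generator
def generator():
    for i in range(1000):
        yield (i, i * 2)

dataset = tf.data.Dataset.from_generator(
    generator,
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.int32),
        tf.TensorSpec(shape=(), dtype=tf.int32)
    )
)

# From files
dataset = tf.data.Dataset.list_files('/path/to/images/*.jpg')  # a dataset of file paths
dataset = tf.data.TextLineDataset('/path/to/text.txt')          # one element per line

# From TFRecord files (the recommended format for production)
raw_dataset = tf.data.TFRecordDataset(['file1.tfrecord', 'file2.tfrecord'])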

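5.2 Transformations

dataset = tf.data.Dataset.range(1000)

# map: element-wise transformation
def preprocess(x):
    return x * 2, x * 3
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
# num_parallel_calls: number of parallel calls; AUTOTUNE tunes it automatically

# filter: keep elements for which the predicate is True
dataset = dataset.filter(lambda x, y: tf.reduce_sum(x) > 0)

# shuffle: randomize the order using a buffer of buffer_size elements
dataset = dataset.shuffle(buffer_size=10000, seed=42)
# A buffer at least as large as the dataset gives a uniform shuffle; a larger buffer means more
# randomness but more memory. The order is reshuffled every epoch by default;
# reshuffle_each_iteration=False keeps the same order in every epoch

# batch: group elements into batches
dataset = dataset.batch(32, drop_remainder=False)
# drop_remainder=True: drop the final incomplete batch (gives a fully static batch dimension)

# prefetch: prefetching (a key performance optimization)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
# While the model trains on the current batch, a background thread prepares the next ones
# AUTOTUNE chooses the buffer size automatically

# repeat: repeat the dataset
finite = dataset.repeat(count=3)  # 3 passes
infinite = dataset.repeat()       # repeats forever

# interleave: read several files in an interleaved fashion
files = tf.data.Dataset.list_files('*.tfrecord')
records = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=4,                      # read 4 files concurrently
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=False                  # allow a non-deterministic order for better throughput
)

# cache: cache elements in memory or on disk
cached = dataset.cache()                               # in-memory cache
cached_on_disk = dataset.cache(filename='/tmp/cache')  # on-disk cache (persists across runs)
# Put cache before shuffle and repeat, so the source is read only once and each epoch is still
# reshuffled (never cache an infinitely repeated dataset)

# Common pipeline orderings:
# 1. map → shuffle → batch → prefetch (simple case)
# 2. map → cache → shuffle → batch → repeat → prefetch (fixed number of epochs)
# 3. interleave → map → shuffle → batch → prefetch (reading from files)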
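The effect of buffer_size in shuffle can be seen with a small plain-Python model of a buffered shuffle: fill a buffer with the first buffer_size elements, then repeatedly emit a random buffer slot and refill it with the next input element. This sketch illustrates the buffering idea under that assumption; the function `buffered_shuffle` is hypothetical and is not tf.data's source code.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=None):
    assert buffer_size >= 1
    rng = random.Random(seed)
    it = iter(iterable)
    buffer = []
    # Fill the buffer with the first buffer_size elements
    for item in it:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    # Emit a random slot, then refill that slot with the next input element
    for item in it:
        i = rng.randrange(len(buffer))
        yield buffer[i]
        buffer[i] = item
    # Drain the remaining elements in random order
    rng.shuffle(buffer)
    yield from buffer


data = list(range(10))
print(list(buffered_shuffle(data, buffer_size=1)))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: no shuffling
print(sorted(buffered_shuffle(data, buffer_size=3, seed=0)) == data)  # True: still a permutation
```

With a small buffer, the first output can only be one of the first buffer_size inputs, which is why a buffer much smaller than the dataset leaves sorted or grouped data only locally mixed.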

5.3 Notes on Custom map Functions

# The function passed to map is traced into a graph (as with tf.function), so:
# - work on tensors, not Python/NumPy values (plain Python code runs only once, at trace time)
# - use tf.io instead of Python's open()
# - use tf.image instead of PIL

def parse_image(filename, label):
    image_string = tf.io.read_file(filename)  # use tf.io
    image = tf.image.decode_jpeg(image_string, channels=3)
    image = tf.image.resize(image, [224, 224])  # returns float32
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

dataset = dataset.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)

# Random augmentation inside map
def augment(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    # Keep values in [0, 1]
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

# Augment the training set only (after any cache(), so each epoch gets new random augmentations)
train_dataset = train_dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
# val_dataset is left without augment

5.4 The TFRecord Format

TFRecord is TensorFlow's recommended high-performance binary data format: a sequence of serialized records, typically tf.train.Example protocol buffers.

# Writing TFRecords
def serialize_example(feature0, feature1, label):
    """Serialize one sample into a tf.train.Example (feature0 and feature1 are NumPy arrays)"""
    feature = {
        'feature0': tf.train.Feature(
            float_list=tf.train.FloatList(value=feature0.flatten())
        ),
        'feature1': tf.train.Feature(
            int64_list=tf.train.Int64List(value=feature1.flatten())
        ),
        'label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])
        ),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

# Feature types:
# tf.train.BytesList: byte strings (text, encoded images)
# tf.train.FloatList: list of float32
# tf.train.Int64List: list of int64

# Write to a file (features, extra, labels: the sample arrays; features[i] holds 784 floats
# and extra[i] 10 ints, matching the parser below)
with tf.io.TFRecordWriter('data.tfrecord') as writer:
    for i in range(1000):
        serialized = serialize_example(features[i], extra[i], labels[i])
        writer.write(serialized)

# Compression
options = tf.io.TFRecordOptions(compression_type='GZIP')
with tf.io.TFRecordWriter('data.tfrecord.gz', options) as writer:
    writer.write(serialized)

# Reading TFRecords
def parse_tfrecord_fn(example_proto):
    """Parse one Example from a TFRecord"""
    feature_description = {
        'feature0': tf.io.FixedLenFeature([784], tf.float32),
        'feature1': tf.io.FixedLenFeature([10], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example_proto, feature_description)
    return example['feature0'], example['feature1'], example['label']

# VarLenFeature for variable-length data
# feature_description = {
#     'variable_feat': tf.io.VarLenFeature(tf.float32),
# }
# The parsed value is a SparseTensor

# Full pipeline
dataset = tf.data.TFRecordDataset(
    ['file1.tfrecord.gz', 'file2.tfrecord.gz'],
    compression_type='GZIP'  # or '' (none) / 'ZLIB'; must match how the files were written
)
dataset = dataset.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.shuffle(10000).batch(32).prefetch(tf.data.AUTOTUNE)
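On disk, an uncompressed TFRecord file is a sequence of length-prefixed records, each protected by masked CRC32C checksums: a little-endian uint64 length, a uint32 masked CRC of that length field, the payload bytes, and a uint32 masked CRC of the payload. The plain-Python sketch below writes and reads that framing to show what TFRecordWriter produces. It is for illustration only (slow bit-by-bit CRC, no compression), and the helper names are hypothetical; real code should use tf.io.TFRecordWriter and tf.data.TFRecordDataset.

```python
import struct


def crc32c(data: bytes) -> int:
    """CRC-32C (Castagnoli), bitwise implementation with the reflected polynomial."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate-and-offset mask, which avoids problems when CRCs are computed over stored CRCs."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def frame_record(payload: bytes) -> bytes:
    header = struct.pack("<Q", len(payload))
    return (header + struct.pack("<I", masked_crc(header))
            + payload + struct.pack("<I", masked_crc(payload)))


def read_records(blob: bytes):
    pos = 0
    while pos < len(blob):
        header = blob[pos:pos + 8]
        (length,) = struct.unpack("<Q", header)
        (header_crc,) = struct.unpack("<I", blob[pos + 8:pos + 12])
        if header_crc != masked_crc(header):
            raise ValueError("corrupted length header")
        payload = blob[pos + 12:pos + 12 + length]
        (data_crc,) = struct.unpack("<I", blob[pos + 12 + length:pos + 16 + length])
        if data_crc != masked_crc(payload):
            raise ValueError("corrupted record")
        yield payload
        pos += 16 + length


blob = frame_record(b"first") + frame_record(b"second record")
print(list(read_records(blob)))  # [b'first', b'second record']
```

Because each record carries its own length, a reader can stream through a file without an index, which is also why TFRecords support sequential reads well but not random access.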

6. Keras Integration

TensorFlow 2.x adopts Keras as its official high-level API. This section focuses on features specific to tf.keras (for general Keras knowledge, see the companion article "Machine Learning Frameworks: Keras").

6.1 Three Ways to Build Models

import tensorflow as tf

# 1. Sequential API: a plain stack of layers
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax')
])

# 2. Functional API: arbitrary DAG topologies
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(256, activation='relu')(inputs)
x = tf.keras.layers.Dropout(0.5)(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Multiple inputs / outputs
input_a = tf.keras.Input(shape=(28, 28, 1))
input_b = tf.keras.Input(shape=(10,))
x = tf.keras.layers.Flatten()(input_a)
x = tf.keras.layers.Concatenate()([x, input_b])
out1 = tf.keras.layers.Dense(10, name='class')(x)  # no activation: outputs logits
out2 = tf.keras.layers.Dense(1, name='aux')(x)
model = tf.keras.Model(inputs=[input_a, input_b], outputs=[out1, out2])
model.compile(optimizer='adam',
              # 'class' emits logits, so its loss must use from_logits=True
              loss={'class': tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                    'aux': 'mse'},
              loss_weights={'class': 1.0, 'aux': 0.5})

# 3. Subclassing API: full flexibility
class CustomModel(tf.keras.Model):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
        self.conv2 = tf.keras.layers.Conv2D(64, 3, activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(num_classes)  # logits

    def call(self, inputs, training=None):
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.flatten(x)
        return self.dense(x)

6.2 Custom Training Loops

# Override train_step to customize the logic that model.fit runs
class CustomFitModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')
        self.accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='acc')

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # compiled_loss is the Keras 2 API (tf.keras up to TF 2.15);
            # with Keras 3 use self.compute_loss(x, y, y_pred) instead
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update metrics
        self.loss_tracker.update_state(loss)
        self.accuracy.update_state(y, y_pred)
        return {'loss': self.loss_tracker.result(), 'acc': self.accuracy.result()}

    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        self.loss_tracker.update_state(loss)
        self.accuracy.update_state(y, y_pred)
        return {'loss': self.loss_tracker.result(), 'acc': self.accuracy.result()}

    @property
    def metrics(self):
        # Metrics listed here are reset automatically at the start of each epoch
        return [self.loss_tracker, self.accuracy]

# A fully manual training loop (without model.fit)
# Assumes model (which outputs logits), train_dataset and epochs are already defined
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(x, y):
    # model, optimizer and loss_fn are captured from the enclosing scope
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

for epoch in range(epochs):
    for step, (x_batch, y_batch) in enumerate(train_dataset):
        loss = train_step(x_batch, y_batch)

6.3 add_loss and add_metric

class RegularizedModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(
            10,
            kernel_regularizer=tf.keras.regularizers.l2(0.01),
            activity_regularizer=tf.keras.regularizers.l1(0.001)
        )

    def call(self, inputs, training=None):
        x = self.dense(inputs)
        # Add a custom loss term (here: a small penalty on the output magnitude)
        self.add_loss(0.001 * tf.reduce_mean(tf.square(x)))
        # Terms added with add_loss are collected in model.losses and added to the training loss
        return x

# Layer.add_metric is not available in Keras 3; subclassing tf.keras.metrics.Metric, as below,
# is the portable way to define a custom metric.
# Example: a streaming binary F1 score for the positive class 1
class F1Score(tf.keras.metrics.Metric):
    def __init__(self, name='f1_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.true_positives = self.add_weight(name='tp', initializer='zeros')
        self.false_positives = self.add_weight(name='fp', initializer='zeros')
        self.false_negatives = self.add_weight(name='fn', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.argmax(y_pred, axis=-1)                   # predicted class per sample
        y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int64)  # accept (batch,) or (batch, 1) labels
        tp = tf.cast(tf.equal(y_pred, 1) & tf.equal(y_true, 1), tf.float32)
        fp = tf.cast(tf.equal(y_pred, 1) & tf.equal(y_true, 0), tf.float32)
        fn = tf.cast(tf.equal(y_pred, 0) & tf.equal(y_true, 1), tf.float32)
        if sample_weight is not None:
            # Weight each sample before summing
            w = tf.cast(tf.reshape(sample_weight, [-1]), tf.float32)
            tp, fp, fn = tp * w, fp * w, fn * w
        self.true_positives.assign_add(tf.reduce_sum(tp))
        self.false_positives.assign_add(tf.reduce_sum(fp))
        self.false_negatives.assign_add(tf.reduce_sum(fn))

    def result(self):
        precision = self.true_positives / (self.true_positives + self.false_positives + 1e-7)
        recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-7)
        return 2 * precision * recall / (precision + recall + 1e-7)

    def reset_state(self):
        self.true_positives.assign(0.0)
        self.false_positives.assign(0.0)
        self.false_negatives.assign(0.0)
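Why the metric accumulates TP/FP/FN in variables instead of averaging per-batch F1 scores: F1 is not linear in the counts, so the mean of per-batch scores generally differs from the F1 of the whole epoch. A plain-Python sketch with made-up binary labels (class 1 is the positive class, as in F1Score above; the helpers `counts` and `f1_from_counts` are hypothetical):

```python
def counts(y_true, y_pred, positive=1):
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    fp = sum(t != positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    return tp, fp, fn


def f1_from_counts(tp, fp, fn, eps=1e-7):
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return 2 * precision * recall / (precision + recall + eps)


batches = [([1, 0, 0, 0], [1, 0, 0, 0]),   # (y_true, y_pred) for batch 1
           ([1, 1, 1, 0], [0, 0, 1, 1])]   # batch 2

# Streaming: accumulate counts across batches, compute F1 once (what F1Score does)
totals = [0, 0, 0]
for y_true, y_pred in batches:
    for i, c in enumerate(counts(y_true, y_pred)):
        totals[i] += c
streaming_f1 = f1_from_counts(*totals)

# Naive: average of per-batch F1 scores
mean_of_batches = sum(f1_from_counts(*counts(t, p)) for t, p in batches) / len(batches)

print(round(streaming_f1, 3), round(mean_of_batches, 3))  # 0.571 0.7
```

Only the streaming value equals the F1 you would get by scoring all eight examples at once.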

6.4 The Callbacks System

import time

import numpy as np
import tensorflow as tf

# Built-in callbacks
callbacks = [
    # Early stopping
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=10, mode='min',
        restore_best_weights=True, min_delta=0.001,
        baseline=None, verbose=1
    ),

    # Model checkpoints: monitoring val_accuracy requires metrics=['accuracy'] in compile();
    # Keras 3 expects a .keras suffix here (or .weights.h5 with save_weights_only=True)
    tf.keras.callbacks.ModelCheckpoint(
        filepath='checkpoints/model-{epoch:02d}-{val_loss:.2f}.h5',
        monitor='val_accuracy', mode='max',
        save_best_only=True, save_weights_only=False,
        save_freq='epoch', verbose=1
    ),

    # TensorBoard
    tf.keras.callbacks.TensorBoard(
        log_dir='./logs',
        histogram_freq=1,         # log weight histograms every epoch
        write_graph=True,         # log the computation graph
        write_images=False,
        update_freq='epoch',      # or an integer: log every N batches
        profile_batch='500,520',  # profile batches 500 through 520
        embeddings_freq=1,        # visualize Embedding layers every epoch
    ),

    # ReduceLROnPlateau: multiply the LR by factor after patience epochs without improvement
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.2,
        patience=5, min_lr=1e-7, mode='min',
        cooldown=2, verbose=1
    ),

    # CSV log
    tf.keras.callbacks.CSVLogger('training_log.csv', separator=',', append=False),

    # TerminateOnNaN: stop training as soon as the loss becomes NaN
    tf.keras.callbacks.TerminateOnNaN(),
]

# Custom callback
class CustomLogger(tf.keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.epoch_times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        elapsed = time.time() - self.epoch_start
        self.epoch_times.append(elapsed)
        print(f'\nEpoch {epoch+1} took {elapsed:.1f}s - '
              f'loss: {logs["loss"]:.4f} - val_loss: {logs.get("val_loss", "N/A")}')

    def on_batch_end(self, batch, logs=None):  # alias of on_train_batch_end
        if batch % 100 == 0:
            print(f'  Batch {batch}: loss={logs["loss"]:.4f}')

    def on_train_end(self, logs=None):
        print(f'Training completed. Avg epoch time: '
              f'{np.mean(self.epoch_times):.1f}s')
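The interplay of mode='min', min_delta and patience in EarlyStopping can be reasoned about with a small simulation. This is a simplified plain-Python model of the rule (an epoch counts as an improvement only if the metric beats the best value so far by more than min_delta; training stops once patience epochs pass without improvement), not Keras' exact source. The function name and the loss values are made up.

```python
def early_stopping_epoch(val_losses, patience, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for mode='min', or (None, best_epoch) if training never stops."""
    best = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:          # a significant improvement
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch     # restore_best_weights would reload best_epoch
    return None, best_epoch


losses = [0.90, 0.70, 0.65, 0.6495, 0.66, 0.64, 0.67]
print(early_stopping_epoch(losses, patience=2, min_delta=0.001))  # (4, 2)
```

With min_delta=0.001 the tiny gain at epoch 3 (0.6495) does not count, so training stops at epoch 4 even though epoch 5 would have improved; with min_delta=0 the same run continues to the end with epoch 5 as the best.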

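7. Distributed Training Strategies

TF 2.x unifies distributed training behind the tf.distribute.Strategy API.

7.1 MirroredStrategy (Single Machine, Multiple GPUs)

# MirroredStrategy: synchronous training; every GPU holds a full replica of the model,
# and gradients are all-reduced across replicas after each step

strategy = tf.distribute.MirroredStrategy(
    devices=None,          # None = all visible GPUs
    cross_device_ops=None  # None = chosen automatically (NCCL all-reduce on GPUs when available)
)

print(f'Number of devices: {strategy.num_replicas_in_sync}')

# The global batch is split evenly across the replicas
PER_REPLICA_BATCH_SIZE = 64
GLOBAL_BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

# Variables (model, optimizer, metrics) must be created inside strategy.scope()
with strategy.scope():
    model = create_model()  # create_model: a user-defined model factory
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
# The learning rate is NOT scaled automatically; if the global batch grows with the number of
# replicas, apply a scaling rule (e.g. linear scaling) yourself

# Dataset: batch with the GLOBAL batch size (x, y: training arrays); model.fit splits each
# batch across the replicas. The auto-shard policy only matters when several workers read
# the dataset (DATA: each worker reads everything and keeps its share of elements)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(GLOBAL_BATCH_SIZE)
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
dataset = dataset.with_options(options)

# Training
model.fit(dataset, epochs=10)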
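The batch arithmetic behind MirroredStrategy, plus the manual learning-rate scaling mentioned above, sketched in plain Python. For evenly divisible batches each replica gets global_batch / num_replicas examples; the spread of a remainder in `split_global_batch` is an assumption for illustration. The linear scaling rule (multiply a base LR by the factor by which the global batch grew) is one common heuristic, not something TensorFlow applies for you. All names and numbers here are hypothetical.

```python
def split_global_batch(global_batch_size, num_replicas):
    """Per-replica batch sizes when a global batch is split as evenly as possible."""
    base, extra = divmod(global_batch_size, num_replicas)
    return [base + 1 if i < extra else base for i in range(num_replicas)]


def linear_scaled_lr(base_lr, base_batch_size, global_batch_size):
    """Linear scaling rule: the LR grows in proportion to the global batch size."""
    return base_lr * global_batch_size / base_batch_size


num_replicas = 4
per_replica = 64
global_batch = per_replica * num_replicas              # 256
print(split_global_batch(global_batch, num_replicas))  # [64, 64, 64, 64]
print(linear_scaled_lr(0.001, 64, global_batch))       # 0.004
```

Keeping the per-replica batch fixed and growing the global batch is the usual choice, since it keeps each GPU's memory use constant as replicas are added.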

7.2 MultiWorkerMirroredStrategy (Multiple Machines, Multiple GPUs)

# Multi-machine distributed training
import json
import os

# Set TF_CONFIG on every node before creating the strategy:
# the cluster list is identical on all nodes, only task.index differs
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ['192.168.1.10:12345', '192.168.1.11:12345', '192.168.1.12:12345']
    },
    'task': {'type': 'worker', 'index': 0}  # index of this node (worker 0 acts as the chief)
})

strategy = tf.distribute.MultiWorkerMirroredStrategy(
    communication_options=tf.distribute.experimental.CommunicationOptions(
        implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
    )
)

with strategy.scope():
    model = create_model()
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# The BackupAndRestore callback provides fault tolerance: a restarted job resumes from the last backup
callback = tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup')
model.fit(dataset, epochs=10, callbacks=[callback])

7.3 TPUStrategy

# TPU training setup
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    model = create_model()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# TPU notes:
# - Static shapes are required more strictly (e.g. batch with drop_remainder=True)
# - The global batch size should be divisible by the number of TPU cores (8 on a v2-8/v3-8 slice)
# - Avoid tf.py_function (Python code cannot run on the TPU)
# - float32 works; for speed use the 'mixed_bfloat16' mixed-precision policy,
#   since TPUs support bfloat16 natively
# - Use tf.data.AUTOTUNE to tune the input pipeline

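7.4 Custom Distributed Training Loops

strategy = tf.distribute.MirroredStrategy()
PER_REPLICA_BATCH_SIZE = 64
GLOBAL_BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

with strategy.scope():
    model = create_model()
    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    # Reduction.NONE is essential: it keeps per-example losses so they can be averaged
    # over the GLOBAL batch rather than per replica
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

def compute_loss(y_true, y_pred):
    per_example_loss = loss_fn(y_true, y_pred)
    # Sum of this replica's per-example losses divided by the global batch size
    return tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE
    )

# Distribute the dataset (train_dataset: an unbatched tf.data.Dataset);
# each replica receives its slice of every global batch
train_dist_dataset = strategy.experimental_distribute_dataset(
    train_dataset.batch(GLOBAL_BATCH_SIZE))

@tf.function
def distributed_train_step(dataset_inputs):
    def step_fn(inputs):
        images, labels = inputs
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = compute_loss(labels, predictions)

        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_accuracy.update_state(labels, predictions)
        return loss

    per_replica_losses = strategy.run(step_fn, args=(dataset_inputs,))
    # Summing the per-replica losses gives the mean loss over the global batch
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0
    for x in train_dist_dataset:
        total_loss += float(distributed_train_step(x))
        num_batches += 1
    print(f'Epoch {epoch}: loss={total_loss / num_batches:.4f}, '
          f'accuracy={float(train_accuracy.result()):.4f}')
    train_accuracy.reset_state()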
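Why step_fn divides by the global batch size and the driver reduces with SUM: if each replica sums its per-example losses and divides by the number of examples in the whole global batch, adding the replica results yields exactly the mean over that batch, whereas averaging each replica's own mean is wrong when replicas receive unequal shares (as can happen on a final partial batch). A plain-Python illustration with made-up loss values; `average_loss` mimics the sum-then-divide behavior of tf.nn.compute_average_loss, and here the global batch size passed in is the actual number of examples in the step.

```python
def average_loss(per_example_losses, global_batch_size):
    """Mimics tf.nn.compute_average_loss: this replica's loss sum / global batch size."""
    return sum(per_example_losses) / global_batch_size


# Per-example losses on 2 replicas; this global batch is uneven (3 + 1 examples)
replica_losses = [[0.2, 0.4, 0.6], [1.0]]
global_batch_size = sum(len(r) for r in replica_losses)  # 4

summed = sum(average_loss(r, global_batch_size) for r in replica_losses)
true_mean = sum(sum(r) for r in replica_losses) / global_batch_size
mean_of_means = sum(sum(r) / len(r) for r in replica_losses) / len(replica_losses)

print(round(summed, 6), round(true_mean, 6), round(mean_of_means, 6))  # 0.55 0.55 0.7
```

Since the gradient of a sum is the sum of gradients, the same argument makes the all-reduced (summed) gradients equal to the gradient of the global mean loss.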

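7.5 ParameterServerStrategy

# Suited to large-scale asynchronous training (the built-in support in TF 2.9+ is recommended)
# The cluster needs a 'chief' (the coordinator that runs the training program) besides workers and ps
import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'chief': ['chief0:12345'],
        'worker': ['worker0:12345', 'worker1:12345'],
        'ps': ['ps0:12345']
    },
    'task': {'type': 'chief', 'index': 0}
})
resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

if resolver.task_type in ('worker', 'ps'):
    # Workers and parameter servers only run a server and wait for work from the chief
    server = tf.distribute.Server(
        resolver.cluster_spec(), job_name=resolver.task_type,
        task_index=resolver.task_id, protocol=resolver.rpc_layer or 'grpc', start=True)
    server.join()
else:
    strategy = tf.distribute.ParameterServerStrategy(resolver)

# Train with Model.fit (steps_per_epoch is required) or a custom loop driven by
# tf.distribute.coordinator.ClusterCoordinator. Variables live on the ps tasks; large variables are
# sharded only if you pass variable_partitioner
# (e.g. tf.distribute.experimental.partitioners.MinSizePartitioner)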

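8. Model Deployment with TF Serving

8.1 Exporting a SavedModel

# Option 1: model.save() (Keras models; Keras 2, i.e. tf.keras up to TF 2.15)
model.save('/tmp/my_model/1/')  # a path without an .h5/.keras suffix is saved as a SavedModel
# With Keras 3 (TF 2.16+), model.save() needs a .keras file; use model.export('/tmp/my_model/1/') instead
# Directory layout (the trailing 1 is the version directory TF Serving expects):
# /tmp/my_model/1/
#   saved_model.pb     serialized graphs and signatures
#   variables/         weights (variables.index + variables.data-00000-of-00001)
#   assets/            extra files (vocabularies, etc.)
#   keras_metadata.pb  Keras metadata

# Option 2: tf.saved_model.save (lower-level API) with an explicit serving signature
serve_fn = tf.function(lambda x: model(x)).get_concrete_function(
    tf.TensorSpec(shape=[None, 28, 28, 1], dtype=tf.float32, name='input')
)
tf.saved_model.save(model, '/tmp/my_model/1/', signatures=serve_fn)

# Option 3: choose the format with save_format (Keras 2); HDF5 exists but is not suitable for Serving
model.save('my_model.h5', save_format='h5')  # HDF5 format
model.save('my_model', save_format='tf')     # SavedModel format (the default with Keras 2)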

8.2 Signatures in Detail

class ServingModel(tf.keras.Model):
    """Wraps a trained model (assumed to take a dict with 'image' and 'text' inputs) and adds preprocessing"""
    def __init__(self, model):
        super().__init__()
        self.model = model

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8, name='images'),  # raw uint8 pixels
        tf.TensorSpec(shape=[None, 100], dtype=tf.float32, name='text_embedding')
    ])
    def call(self, images, text_embedding):
        # Preprocessing
        images = tf.cast(images, tf.float32) / 255.0
        images = tf.image.resize(images, [224, 224])  # a no-op for this signature; needed if it accepted other sizes
        # Inference
        return self.model({'image': images, 'text': text_embedding})

# Export several signatures (model is assumed to contain a layer named 'embedding'
# that accepts the 784-d input)
@tf.function(input_signature=[tf.TensorSpec([None, 784], tf.float32)])
def serving_fn(inputs):
    return {'logits': model(inputs)}

@tf.function(input_signature=[tf.TensorSpec([None, 784], tf.float32)])
def serving_fn_embedding(inputs):
    return {'embedding': model.get_layer('embedding')(inputs)}

signatures = {
    'serving_default': serving_fn,
    'embedding': serving_fn_embedding,
}
tf.saved_model.save(model, '/tmp/model/1/', signatures=signatures)

# Load the SavedModel
loaded = tf.saved_model.load('/tmp/model/1/')
print(list(loaded.signatures.keys()))  # contains 'serving_default' and 'embedding'
infer = loaded.signatures['serving_default']
# Signature functions take keyword arguments named after the inputs and return a dict of outputs
result = infer(inputs=tf.constant([[1.0] * 784]))
logits = result['logits']

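8.3 Deploying with TF Serving

# Start TF Serving with Docker (port 8501 = REST, 8500 = gRPC)
# /tmp/my_model must contain numbered version subdirectories such as /tmp/my_model/1
docker run -p 8501:8501 -p 8500:8500 \
  --name tf_serving \
  --mount type=bind,source=/tmp/my_model,target=/models/my_model \
  -e MODEL_NAME=my_model \
  -t tensorflow/serving

# REST API request
curl -X POST http://localhost:8501/v1/models/my_model:predict \
  -H "Content-Type: application/json" \
  -d '{
    "instances": [[1.0, 2.0, 3.0]]
  }'

# gRPC API (Python client; needs the tensorflow-serving-api package)
import grpc
import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

channel = grpc.insecure_channel('localhost:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

data = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)  # example input; must match the signature's shape

request = predict_pb2.PredictRequest()
request.model_spec.name = 'my_model'
request.model_spec.signature_name = 'serving_default'
# 'input' and 'output' are placeholders for the signature's real input/output names;
# list them with: saved_model_cli show --dir /tmp/my_model/1 --all
request.inputs['input'].CopyFrom(tf.make_tensor_proto(data))

response = stub.Predict(request, timeout=10.0)
result = tf.make_ndarray(response.outputs['output'])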
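TF Serving's REST predict API accepts either the row format ("instances", one entry per example) or the columnar format ("inputs", one batched value per input name). This standard-library sketch builds both request bodies and a urllib Request object without sending it; the helper names, the input name "input" and the URL are assumptions that must match your model's signature and deployment.

```python
import json
import urllib.request


def row_format(examples, input_name=None):
    """One entry per example; with named inputs, each entry is a dict."""
    if input_name is None:
        return {"instances": examples}
    return {"instances": [{input_name: ex} for ex in examples]}


def columnar_format(examples, input_name="input"):
    """All examples batched under each input name."""
    return {"inputs": {input_name: examples}}


batch = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
body = json.dumps(row_format(batch)).encode("utf-8")
request = urllib.request.Request(
    "http://localhost:8501/v1/models/my_model:predict",
    data=body,
    headers={"Content-Type": "application/json"},
    method="POST",
)
print(request.get_method(), json.loads(body))
# Sending it requires a running server: urllib.request.urlopen(request).read()
```

The row format is convenient when each example is independent; the columnar format is more compact for models with several inputs, since each input's batch is sent as one array.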

8.4 Batching Configuration

# Model configuration file, passed with --model_config_file=models.config

# models.config:
# model_config_list {
#   config {
#     name: 'my_model'
#     base_path: '/models/my_model'
#     model_platform: 'tensorflow'
#     model_version_policy { specific { versions: [1] } }
#   }
# }

# Enable server-side dynamic batching
# --enable_batching --batching_parameters_file=batching.config
# (with Docker, append these flags after the image name: ... tensorflow/serving --enable_batching ...)

# batching.config:
# max_batch_size { value: 128 }         # largest batch the server will form
# batch_timeout_micros { value: 5000 }  # wait at most 5 ms to fill a batch
# max_enqueued_batches { value: 100 }   # queue length before new requests are rejected
# num_batch_threads { value: 4 }        # threads that process batches in parallel
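To build intuition for max_batch_size and batch_timeout_micros, here is a simplified single-queue simulation in plain Python: a batch closes when it reaches max_batch_size or when batch_timeout_micros has elapsed since its first request arrived. This is an assumed model of the latency/throughput trade-off, not TF Serving's actual scheduler, and it ignores max_enqueued_batches and num_batch_threads; the function name and arrival times are hypothetical.

```python
def simulate_batching(arrival_times_us, max_batch_size, batch_timeout_us):
    """Group sorted request arrival times (microseconds) into batches.

    Returns a list of (dispatch_time_us, batch_size).
    """
    batches = []
    start = None  # arrival time of the first request in the open batch
    size = 0
    for t in arrival_times_us:
        if size and t - start >= batch_timeout_us:
            batches.append((start + batch_timeout_us, size))  # closed by the timeout
            size = 0
        if size == 0:
            start = t
        size += 1
        if size == max_batch_size:
            batches.append((t, size))                         # closed because it is full
            size = 0
    if size:
        batches.append((start + batch_timeout_us, size))
    return batches


arrivals = [0, 100, 200, 300, 9000, 9100]
print(simulate_batching(arrivals, max_batch_size=3, batch_timeout_us=5000))
# [(200, 3), (5300, 1), (14000, 2)]
```

Under heavy traffic batches fill up before the timeout, so the timeout mostly bounds the extra latency paid by requests that arrive during quiet periods (here the request at t=300 waits the full 5 ms).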

9. On-Device Deployment with TF Lite

9.1 Model Conversion

import tensorflow as tf

# Load a Keras model
model = tf.keras.models.load_model('model.h5')

# Convert to TF Lite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()  # the converted model as bytes (a FlatBuffer)

with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

# Convert from a SavedModel
converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/model/1/')
tflite_model = converter.convert()

# Convert from ConcreteFunctions (concrete_func: e.g. obtained with get_concrete_function, see 4.3)
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_func], trackable_obj=model
)
tflite_model = converter.convert()

9.2 量化策略

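# 1. Dynamic range quantization
# Weights are quantized to int8 at conversion time. Activations are stored as
# float32 and, for ops that support it, quantized on the fly at inference.
# The model shrinks by roughly 4x, and no calibration data is needed.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# 2. Full integer quantization
# Weights and activations are both int8. Activation ranges are calibrated
# with a representative dataset.
# `calibration_dataset` is assumed to be a tf.data.Dataset that yields one
# model input (e.g. shape [1, H, W, C]) per element.
def representative_dataset():
    for data in calibration_dataset.take(100):
        yield [tf.cast(data, tf.float32)]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
# Make conversion fail if any op has no int8 implementation
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8  # or tf.uint8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

# 3. Float16 quantization
# Weights are stored as float16 (about 2x smaller). This suits the GPU
# delegate, which can compute in float16 directly.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()

# 4. Integer quantization with float fallback
# Ops without an integer kernel fall back to float, and model inputs/outputs
# stay float32. Use a fresh converter so the float16 setting above does not
# carry over.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
    tf.lite.OpsSet.TFLITE_BUILTINS,  # float fallback
]
tflite_model = converter.convert()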
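The representative dataset exists to supply activation ranges. TF Lite's int8 scheme is affine: `real_value = scale * (int8_value - zero_point)`. Activations use an asymmetric zero point, and weights are quantized symmetrically (zero_point = 0), per channel. The NumPy sketch below derives `scale` and `zero_point` for a single activation tensor from a simple min/max calibration and checks the round-trip error. The converter's actual calibration and rounding details may differ, so this illustrates the formula rather than reimplementing the converter.

```python
import numpy as np


def int8_qparams(x_min: float, x_max: float):
    """Asymmetric int8 parameters mapping [x_min, x_max] onto [-128, 127]."""
    # The range must include 0 so that real zero is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / 255.0
    zero_point = int(round(-128 - x_min / scale))
    return scale, zero_point


def quantize(x, scale, zero_point):
    q = np.round(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)


def dequantize(q, scale, zero_point):
    return scale * (q.astype(np.float32) - zero_point)


# Stand-in for activations observed while iterating a representative dataset.
rng = np.random.default_rng(0)
activations = rng.uniform(-1.0, 3.0, size=1000).astype(np.float32)

scale, zp = int8_qparams(float(activations.min()), float(activations.max()))
restored = dequantize(quantize(activations, scale, zp), scale, zp)
# Inside the calibrated range, the error is at most about half a step.
print(scale, zp, np.abs(restored - activations).max() <= scale / 2 + 1e-6)
```

Values outside the calibrated range are clipped. This is why the representative samples should cover the inputs the model will see in production.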

9.3 Delegate Acceleration

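import tensorflow as tf

# Delegates offload inference work to hardware accelerators

# GPU delegate (Android/iOS), using the lightweight tflite_runtime package
import tflite_runtime.interpreter as tflite

# The shared-library name is platform- and build-specific
gpu_delegate = tflite.load_delegate('libtensorflowlite_gpu_delegate.so')
interpreter = tflite.Interpreter(
    model_path='model.tflite',
    experimental_delegates=[gpu_delegate]
)

# The same with the full TensorFlow package:
# tf.lite.experimental.load_delegate(library, options=None), where `options`
# is an optional dict of delegate-specific settings
gpu_delegate = tf.lite.experimental.load_delegate('libtensorflowlite_gpu_delegate.so')
interpreter = tf.lite.Interpreter(
    model_path='model.tflite',
    experimental_delegates=[gpu_delegate]
)

# NNAPI delegate (Android Neural Networks API)
# Usually enabled from the Android Java API, e.g.
#   new Interpreter(modelFile, new Interpreter.Options().addDelegate(new NnApiDelegate()))
# From Python it needs a delegate library built for the device; the file name
# below is illustrative
interpreter = tf.lite.Interpreter(
    model_path='model.tflite',
    experimental_delegates=[
        tf.lite.experimental.load_delegate('libtensorflowlite_nnapi_delegate.so')
    ]
)

# XNNPACK delegate (optimized CPU kernels for ARM/x86)
# Recent TF Lite releases apply it by default, so no extra code is needed.
# To benchmark against the plain builtin kernels, disable default delegates:
interpreter = tf.lite.Interpreter(
    model_path='model.tflite',
    experimental_op_resolver_type=(
        tf.lite.experimental.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
    )
)

# Core ML delegate (iOS)
# Not available from Python: add the CoreML subspec of the TensorFlowLiteSwift
# (or TensorFlowLiteObjC) CocoaPod and pass CoreMLDelegate() to the
# Interpreter's `delegates` argument in Swift

# Hexagon DSP delegate (Qualcomm)
# Accelerates quantized models on the Hexagon DSP via the Hexagon NN library;
# configured through the Android API rather than Python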
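Delegates are not available on every device, so deployment code often tries an accelerator first and falls back to the CPU. The sketch below shows that pattern. It relies on the loader raising `ValueError` when a library cannot be loaded, which is how `tf.lite.experimental.load_delegate` reports a failure. The helper name `make_interpreter` and its injectable `load_delegate`/`interpreter_cls` parameters are hypothetical, chosen so the logic can be exercised without real delegate libraries.

```python
from typing import Callable, Iterable, Optional


def make_interpreter(model_path: str,
                     delegate_libs: Iterable[str],
                     load_delegate: Optional[Callable] = None,
                     interpreter_cls: Optional[Callable] = None):
    """Try each delegate library in order and fall back to the CPU.

    Returns (interpreter, library_name_used_or_None).
    """
    if load_delegate is None or interpreter_cls is None:
        import tensorflow as tf  # imported lazily: only needed for real use
        load_delegate = load_delegate or tf.lite.experimental.load_delegate
        interpreter_cls = interpreter_cls or tf.lite.Interpreter

    for lib in delegate_libs:
        try:
            delegate = load_delegate(lib)
        except ValueError:
            continue  # library missing or unsupported on this device
        return interpreter_cls(model_path=model_path,
                               experimental_delegates=[delegate]), lib
    # No delegate could be loaded: use the default CPU kernels.
    return interpreter_cls(model_path=model_path), None


# On a real device (library names are platform-specific):
# interpreter, used = make_interpreter('model.tflite',
#                                      ['libtensorflowlite_gpu_delegate.so'])
```

Returning the name of the delegate that was used makes it easy to log which acceleration path each device actually took.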

9.4 Metadata and Model Card

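import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

# Model-level metadata
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "Image Classifier"
model_meta.description = "Classify images into 1000 categories"
model_meta.version = "v1.0"
model_meta.author = "Your Name"
model_meta.license = "Apache 2.0"

# Input tensor metadata
input_meta = _metadata_fb.TensorMetadataT()
input_meta.name = "image"
input_meta.description = "Input image to be classified."
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties

# Output tensor metadata: the label file describes the output's class axis,
# so it is attached to the output tensor, not the input
output_meta = _metadata_fb.TensorMetadataT()
output_meta.name = "probability"
output_meta.description = "Probabilities of the 1000 labels."
label_file = _metadata_fb.AssociatedFileT()
label_file.name = "labels.txt"
label_file.description = "Labels for classification outputs"
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
output_meta.associatedFiles = [label_file]

# Assemble the subgraph and serialize the metadata into a FlatBuffer
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [output_meta]
model_meta.subgraphMetadata = [subgraph]

builder = flatbuffers.Builder(0)
builder.Finish(model_meta.Pack(builder),
               _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = builder.Output()

# Write the metadata and the label file into the TF Lite model (in place)
populator = _metadata.MetadataPopulator.with_model_file('model.tflite')
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(['labels.txt'])
populator.populate()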

9.5 TF Lite Inference

import numpy as np
import tensorflow as tf

# Inference from Python (for testing; on-device apps use the Java/Swift/C++ APIs)
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

# Inspect input/output details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(f"Input: {input_details[0]['shape']}, dtype: {input_details[0]['dtype']}")
print(f"Output: {output_details[0]['shape']}, dtype: {output_details[0]['dtype']}")

# Set the input. `preprocessed_image` is assumed to be a float array that already
# includes the batch dimension and matches input_details[0]['shape']
input_data = np.array(preprocessed_image, dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

# Run inference
interpreter.invoke()

# Read the output
output_data = interpreter.get_tensor(output_details[0]['index'])
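If the model was converted with full integer quantization (`inference_input_type`/`inference_output_type` set to `tf.int8`), feeding float32 data as above fails. The input must be quantized with the tensor's `(scale, zero_point)` from the `'quantization'` entry of its details, and the output dequantized the same way. The sketch below assumes a single input and a single output. The helper name `run_tflite` is hypothetical, and it only calls the standard `Interpreter` methods shown above, so any object exposing them works.

```python
import numpy as np


def run_tflite(interpreter, x: np.ndarray) -> np.ndarray:
    """Run one inference, handling int8/uint8 quantized input and output."""
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]

    if np.issubdtype(inp['dtype'], np.integer):
        scale, zero_point = inp['quantization']
        info = np.iinfo(inp['dtype'])
        # real = scale * (q - zero_point)  =>  q = real / scale + zero_point
        q = np.round(x / scale + zero_point)
        x = np.clip(q, info.min, info.max).astype(inp['dtype'])
    else:
        x = x.astype(inp['dtype'])

    interpreter.set_tensor(inp['index'], x)
    interpreter.invoke()
    y = interpreter.get_tensor(out['index'])

    if np.issubdtype(out['dtype'], np.integer):
        scale, zero_point = out['quantization']
        y = scale * (y.astype(np.float32) - zero_point)
    return y


# With the interpreter created above:
# probs = run_tflite(interpreter, np.asarray(preprocessed_image, np.float32))
```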

10. Complete CIFAR-10 Training Pipeline

The following complete CIFAR-10 example combines data augmentation, mixed precision, distributed training, and model export:

import tensorflow as tf
import numpy as np
import os

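# === Configuration ===
AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 128  # global batch size; MirroredStrategy splits it across replicas
EPOCHS = 50
LEARNING_RATE = 0.001
IMAGE_SIZE = (32, 32)

# Enable mixed precision (float16 compute, float32 variables)
tf.keras.mixed_precision.set_global_policy('mixed_float16')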

# === Data preparation ===
def load_cifar10():
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    return x_train, y_train, x_test, y_test

x_train, y_train, x_test, y_test = load_cifar10()

# Data augmentation (Keras preprocessing layers; the random ops apply only when training=True)
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
    tf.keras.layers.RandomTranslation(0.1, 0.1),
    tf.keras.layers.RandomContrast(0.1),
], name='data_augmentation')

# Preprocessing function
def preprocess(image, label, training=False):
    image = tf.cast(image, tf.float32)  # pixel values stay in [0, 255]
    if training:
        image = data_augmentation(image, training=True)
    # resnet_v2.preprocess_input rescales [0, 255] pixels to [-1, 1]
    image = tf.keras.applications.resnet_v2.preprocess_input(image)
    return image, label

# Build the tf.data pipelines
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(50000)  # buffer covers the whole training set
train_dataset = train_dataset.map(
    lambda x, y: preprocess(x, y, training=True),
    num_parallel_calls=AUTOTUNE
)
train_dataset = train_dataset.batch(BATCH_SIZE, drop_remainder=True)
train_dataset = train_dataset.prefetch(AUTOTUNE)

test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.map(
    lambda x, y: preprocess(x, y, training=False),
    num_parallel_calls=AUTOTUNE
)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.prefetch(AUTOTUNE)

# === Model definition ===
def create_model():
    inputs = tf.keras.Input(shape=(32, 32, 3))

    # Stem convolution (small 32x32 images do not need a large kernel)
    x = tf.keras.layers.Conv2D(64, 3, padding='same',
                               kernel_regularizer=tf.keras.regularizers.l2(1e-4))(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    # ResNet-style residual block (a small network sized for CIFAR-10)
    def residual_block(x, filters, stride=1):
        shortcut = x
        if stride != 1 or x.shape[-1] != filters:
            # 1x1 projection so the shortcut matches the main path's shape
            shortcut = tf.keras.layers.Conv2D(
                filters, 1, strides=stride, use_bias=False
            )(shortcut)
            shortcut = tf.keras.layers.BatchNormalization()(shortcut)

        x = tf.keras.layers.Conv2D(filters, 3, strides=stride,
                                   padding='same', use_bias=False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.Conv2D(filters, 3, padding='same',
                                   use_bias=False)(x)
        x = tf.keras.layers.BatchNormalization()(x)

        x = tf.keras.layers.Add()([x, shortcut])
        x = tf.keras.layers.ReLU()(x)
        return x

    # Build the network: feature maps go 32x32 -> 16x16 -> 8x8
    x = residual_block(x, 64)
    x = residual_block(x, 64)
    x = residual_block(x, 128, stride=2)
    x = residual_block(x, 128)
    x = residual_block(x, 256, stride=2)
    x = residual_block(x, 256)

    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    # dtype='float32' keeps the logits in float32 under mixed precision, so the
    # loss (and the softmax in serving_fn) is computed in full precision
    outputs = tf.keras.layers.Dense(10, dtype='float32',
                                    kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x)

    model = tf.keras.Model(inputs, outputs)
    return model

# === Strategy: single machine, multiple GPUs ===
strategy = tf.distribute.MirroredStrategy()
print(f'Number of devices: {strategy.num_replicas_in_sync}')

# Create and compile the model inside the scope so its variables are mirrored
# across devices (tf.keras.optimizers.AdamW requires TF 2.11 or later)
with strategy.scope():
    model = create_model()
    model.compile(
        optimizer=tf.keras.optimizers.AdamW(
            learning_rate=LEARNING_RATE, weight_decay=1e-4
        ),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
            tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top5_accuracy'),
        ]
    )

# === Training ===
os.makedirs('checkpoints', exist_ok=True)  # make sure the checkpoint directory exists

callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy', patience=10, restore_best_weights=True
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7
    ),
    tf.keras.callbacks.ModelCheckpoint(
        'checkpoints/cifar10_best.h5',
        monitor='val_accuracy', save_best_only=True, mode='max'
    ),
    tf.keras.callbacks.TensorBoard(log_dir='./logs/cifar10', histogram_freq=1),
]

history = model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1
)

# === Evaluation ===
test_loss, test_acc, test_top5 = model.evaluate(test_dataset)
print(f'Test accuracy: {test_acc:.4f}, Top-5 accuracy: {test_top5:.4f}')

# === Export a SavedModel ===
# The serving signature accepts raw uint8 images and applies the same
# preprocessing as training, so clients do not have to reimplement it
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 32, 32, 3], dtype=tf.uint8)])
def serving_fn(image):
    image = tf.cast(image, tf.float32)
    image = tf.keras.applications.resnet_v2.preprocess_input(image)  # [0, 255] -> [-1, 1]
    logits = model(image, training=False)
    probs = tf.nn.softmax(logits)
    return {'probabilities': probs, 'class_ids': tf.argmax(probs, axis=-1)}

export_path = './export/cifar10_classifier/1'
tf.saved_model.save(
    model,
    export_path,
    signatures={'serving_default': serving_fn}
)
print(f'Model saved to {export_path}')

# === TF Lite export (dynamic range quantization) ===
converter = tf.lite.TFLiteConverter.from_saved_model(export_path)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open('cifar10_model.tflite', 'wb') as f:
    f.write(tflite_model)
print(f'TF Lite model size: {len(tflite_model) / 1024:.1f} KB')
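The `mixed_float16` policy above relies on loss scaling. The smallest positive float16 value is about 6e-8, so small gradients underflow to zero. Keras counters this by multiplying the loss by a scale factor before backpropagation and dividing the gradients by the same factor afterwards. `model.fit` does this automatically under the mixed precision policy, while custom training loops use `tf.keras.mixed_precision.LossScaleOptimizer`. The NumPy sketch below only demonstrates the underflow and its fix on a single value. The fixed scale of 1024 is an arbitrary illustrative choice; Keras adjusts the scale dynamically.

```python
import numpy as np

grad = 1e-8          # a small but meaningful float32 gradient
loss_scale = 1024.0  # illustrative fixed scale; Keras uses a dynamic one

# Without scaling, casting to float16 flushes the gradient to zero.
unscaled_fp16 = np.float16(grad)

# With scaling, the value is large enough to survive in float16 ...
scaled_fp16 = np.float16(grad * loss_scale)
# ... and is divided back in float32 before the optimizer update.
recovered = np.float32(scaled_fp16) / loss_scale

print(unscaled_fp16, scaled_fp16, recovered)
```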

11. Debugging and Performance Optimization

11.1 Common Debugging Tools

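import tensorflow as tf

# tf.debugging assertions (x, y and loss stand for tensors from your own code)
tf.debugging.assert_shapes([(x, ('N', 784))])  # any batch size N, 784 features
tf.debugging.assert_positive(x)
tf.debugging.assert_near(x, y, atol=1e-4)
tf.debugging.check_numerics(loss, 'Loss is NaN or Inf')

# tf.print runs on every call inside tf.function
# (a Python print only runs once, while the function is being traced)
@tf.function
def debug_fn(x):
    tf.print('Shape:', tf.shape(x), 'Max:', tf.reduce_max(x))
    return x * 2

# tf.config.run_functions_eagerly(True) runs tf.function bodies eagerly,
# so Python print and pdb breakpoints work inside them
tf.config.run_functions_eagerly(True)
# Restore graph execution after debugging:
tf.config.run_functions_eagerly(False)

# TensorBoard Profiler
# Launch:
#   tensorboard --logdir=./logs
# In code:
#   tf.profiler.experimental.start('./logs/profiler')
#   ... training code ...
#   tf.profiler.experimental.stop()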

11.2 Performance Tips

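# 1. Wrap training steps in tf.function so they run as graphs
# 2. Let tf.data.AUTOTUNE tune buffer sizes and parallelism
# 3. Use prefetch to overlap data loading with computation
# 4. Parallelize map with num_parallel_calls=tf.data.AUTOTUNE
# 5. Store large datasets as TFRecord files
# 6. Use mixed precision training
# 7. Enable XLA compilation (experimental):
#    tf.config.optimizer.set_jit(True)
#    or @tf.function(jit_compile=True)
# 8. Put Keras preprocessing layers (tf.keras.layers.RandomFlip, Rescaling, ...;
#    formerly tf.keras.layers.experimental.preprocessing) inside the model so
#    they run on the GPU instead of in a tf.data map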
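Prefetching (tip 3) helps because of how load and compute times combine. Without it, each step first waits for its batch to be produced and then for the model to consume it, so the step time is load + compute. With a prefetch buffer, loading the next batch overlaps with computing the current one, and the steady-state step time approaches max(load, compute). The pure-Python timing model below is an idealized sketch that assumes one producer, one consumer, fixed per-batch costs and an unbounded buffer. It is not a measurement of tf.data.

```python
def pipeline_time(n_batches: int, t_load: float, t_compute: float,
                  prefetch: bool) -> float:
    """Idealized wall-clock time (seconds) to process n_batches."""
    if not prefetch:
        # Loading and computing strictly alternate.
        return n_batches * (t_load + t_compute)
    # Overlapped: batch i is ready at (i + 1) * t_load, and its compute can
    # start once it is ready and the previous batch has finished.
    finish = 0.0
    for i in range(n_batches):
        ready = (i + 1) * t_load
        finish = max(finish, ready) + t_compute
    return finish


# 100 batches, 30 ms to load and 50 ms to compute each one.
print(pipeline_time(100, 0.030, 0.050, prefetch=False))  # ≈ 8.0 s
print(pipeline_time(100, 0.030, 0.050, prefetch=True))   # ≈ 5.03 s
```

In the compute-bound case above, prefetching hides almost all of the loading time. When loading is the slower side, the total approaches n × t_load instead, and the remaining levers are tips 4 and 5 (parallel map and TFRecord files).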

12. The TF Ecosystem

12.1 TensorFlow Extended (TFX)

TFX is an end-to-end platform for production ML pipelines. Its core components include:

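  • ExampleGen: data ingestion
  • StatisticsGen: computes dataset statistics (using TFDV)
  • SchemaGen: infers a data schema
  • Transform: feature engineering; full-pass statistics are computed over the whole dataset and exported with the transformation graph, so training and online serving apply identical logic
  • Trainer: model training
  • Evaluator: model analysis and validation (using TFMA)
  • Pusher: pushes the validated model to a deployment target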

12.2 TensorBoard Visualization

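import datetime
import tensorflow as tf

# Organize logs by model name and run timestamp (`model_name` is your own identifier)
log_dir = f"logs/{model_name}/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# TensorBoard callback (profile_batch='10,20' profiles batches 10 through 20)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir, histogram_freq=1, write_graph=True,
    update_freq='epoch', profile_batch='10,20'
)

# Custom summaries (value, images, weights and global_step come from your
# training loop; images must be a 4-D [k, height, width, channels] tensor)
train_summary_writer = tf.summary.create_file_writer(log_dir)
with train_summary_writer.as_default():
    tf.summary.scalar('custom_metric', value, step=global_step)
    tf.summary.image('input_images', images, step=global_step, max_outputs=8)
    tf.summary.histogram('weights', weights, step=global_step)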

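This article has walked through the core TensorFlow 2.x stack: Eager Execution and GradientTape, graph compilation with tf.function, high-performance tf.data pipelines, distributed training strategies, and model deployment with TF Serving and TF Lite. With these pieces, readers should be able to take a model from development to production deployment on their own. TensorFlow and PyTorch each have their strengths, so choose based on the use case. TF has the more mature ecosystem for production serving (TF Serving) and mobile deployment (TF Lite), while PyTorch stands out in research and flexibility.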

Author: Leo·Cheung
Link: http://tufusi.com/2022/04/20/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E6%A1%86%E6%9E%B6%E7%AF%87-Tensorflow/
Copyright: Unless otherwise stated, all posts on this blog are licensed under CC BY-NC-SA 4.0. Please credit ONE·PIECE when reposting.