在我们的很多推荐业务场景中,通常一个模型可能是一直不断增量训练的,如果哪天业务需要调整模型结构,去训练一个新模型,但是又不想完全从0开始,希望复用原来模型里面的部分参数,这样冷启动的代价就小很多了。
实际上 TensorFlow 提供了足够的灵活性,我们可以控制从其他模型 restore 部分参数到新的模型里。因为目前生产环境普遍还是在用 tf-1.x,下面分别介绍Low-Level API 和 Estimator API 两种实践。
Low-Level API 实践
在决定从已有的模型预热参数前,可以先将模型ckpt拉到本地,开一个 ipython 或 jupyter,列出模型中的所有参数。1
2
3
4
5
6
7
8
9
10
11
12
13In [6]: tf.train.list_variables(checkpoint_dir)
Out[6]:
[('dense/bias', [1]),
('dense/bias/Adagrad', [1]),
('dense/kernel', [17, 1]),
('dense/kernel/Adagrad', [17, 1]),
('fm/b', [1]),
('fm/b/Adagrad', [1]),
('fm/v', [4809162, 16]),
('fm/v/Adagrad', [4809162, 16]),
('fm/w', [4809162, 1]),
('fm/w/Adagrad', [4809162, 1]),
('global_step', [])]
假如,我们想要从ckpt中预热 fm/v
和 fm/w
两个参数,很简单,通过自定义一个 tf.train.Saver
来控制加载哪些参数:
1 | ... |
以上代码中 recover_vars
定义了要从 ckpt 中恢复的参数,是一个字典形式,key 为 ckpt 中的变量名,从上面我们 list 出来的变量里找即可,value 为要覆盖的变量,即从 ckpt 中找到名字为 key 的变量参数,去覆盖 value 指定的变量。
Estimator API 实践
如果你是用高阶 Estimator API,其实完全可以借助 Estimator 自带的 warm_start 功能来实现。
1 | tf.estimator.WarmStartSettings( |
- ckpt_to_initialize_from:预热模型的ckpt路径
- vars_to_warm_start:要加载哪些变量出来预热,可以通过上述
tf.train.list_variables
方法先列出变量名再决定要哪些变量 - var_name_to_vocab_info:动态词表信息
- var_name_to_prev_var_name:新模型中的变量名 -> 旧模型中的变量名,意思就是加载出来的变量会预热到新模型的变量
如果旧模型中有变量A,新模型有变量A、B,需要将旧模型的变量A恢复到新模型的变量B,如果使用tf warm_start,它既会将旧模型的变量A恢复到新模型的变量B,也会恢复到新模型的变量A。为解决名字冲突问题,我们可以自定义一个 Hook 将上述 low-level api 的使用方式封装一下,实现定制化恢复即可。
1 | class VariableRecoverHook(tf.train.SessionRunHook): |
以上代码实现一个 Hook,其中初始化参数 recover_vars
表示要从 ckpt 中恢复的参数。一般在恢复参数前,也需要list一下旧模型中的参数,找到对应的变量名。有了这个 Hook 后,那么我们就可以在 model_fn
中插入这个 Hook 的实例即可。
1 | def model_fn(features, labels, mode, params): |
小结
实践中我们可以基于 TensorFlow 灵活保存以及恢复参数,当有迁移学习需求时,可以通过定制化 tf.train.Saver
的方式来控制预热指定的参数。目前 TensorFlow 也进入 2.x 时代了,官方主推Keras API,通过 Keras API 可以更加灵活的控制保存以及恢复参数。但是如果你是用 Estimator,则可以直接复用 warm_start 或上述 Hook 实现。
转载请注明出处,本文永久链接:https://sharkdtu.github.io/posts/tf-migrate-learning.html