【云+社区年度征文】tensorflow 2.0 Estimator Keras读取saved model并预测

背景

使用tensorflow2.0以上版本框架用Keras或者Estimator方式保存模型有两种方式加载模型并预测。

Keras框架保存模型后可以直接加载并调用predict方法预测;

estimator将比较麻烦,需要签名并传入tensor才可以预测;

Keras模型预测

代码语言:txt
复制
import tensorflow as tf
from tensorflow import keras
model = tf.keras.models.load_model(export_dir)

dataframe 特征读取与处理

X = dict(dataframe)
c = model.predict(X)
output = np.argmax(c, axis=1)

Estimator模型预测

代码语言:txt
复制
import tensorflow as tf

加载模型 & 签名

imported = tf.saved_model.load(export_dir)
f = imported.signatures["predict"]

代码语言:txt
复制
# 转换为tensor并预测
out_df = pd.DataFrame()
def predict(dataframe):
examples = []
for row in dataframe.itertuples():
feature_map = {}
# 特征处理 将特征放入dict中
example = tf.train.Example(
features=tf.train.Features(
feature = feature_map
)
)
examples.append(example.SerializeToString())

ex = tf.constant(examples)
result = f(examples=ex)
out_df['high_rank_score'] = np.max(result["probabilities"].numpy(), axis=1)
out_df['tag'] = np.argmax(result["probabilities"].numpy(), axis=1)
return out_df</code></pre></div></div><h4 id="7ihuj" name="Ref">Ref</h4><ol class="ol-level-0"><li>http://d0evi1.com/tensorflow/custom_estimators/</li><li>https://www.tensorflow.org/guide/saved_model?hl=zh-cn#%E5%8A%A0%E8%BD%BD%E5%92%8C%E4%BD%BF%E7%94%A8%E8%87%AA%E5%AE%9A%E4%B9%89%E6%A8%A1%E5%9E%8B</li><li>https://zhuanlan.zhihu.com/p/66872472</li><li>https://yinguobing.com/load-savedmodel-of-estimator-by-keras/</li></ol>