博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Keras(十五)tf_record基础API使用
阅读量:4202 次
发布时间:2019-05-26

本文共 3736 字,大约阅读时间需要 12 分钟。

本文将介绍:

  • 构建一个tf.train.Example对象
  • 将tf.train.Example对象存入文件中,生成的tf_record文件.
  • 使用tf.data的API读取tf_record文件,并实现反序列化
  • 将tf.train.Example对象存入压缩文件中,生成的tf_record压缩文件.
  • 使用tf.data的API读取tf_record压缩文件.

一,构建一个ByteList,FloatList,Int64List的对象

#!/usr/bin/env python3# -*- coding: utf-8 -*-import matplotlib as mplimport matplotlib.pyplot as pltimport numpy as npimport sklearnimport pandas as pdimport osimport sysimport timeimport tensorflow as tffrom tensorflow import keras# 打印使用的python库的版本信息print(tf.__version__)print(sys.version_info)for module in mpl, np, pd, sklearn, tf, keras:    print(module.__name__, module.__version__)    # tfrecord 是一种文件格式# -> tf.train.Example可以是一个样本或者一组#    每个Example-> tf.train.Features -> {"key": tf.train.Feature}#       每个Feature有不同的格式-> tf.train.Feature -> tf.train.ByteList/FloatList/Int64List# 1,构建一个ByteList,FloatList,Int64List的对象# tf.train.ByteListfavorite_books = [name.encode('utf-8') for name in ["machine learning", "cc150"]]favorite_books_bytelist = tf.train.BytesList(value = favorite_books)print(favorite_books_bytelist,type(favorite_books_bytelist))# tf.train.FloatListhours_floatlist = tf.train.FloatList(value = [15.5, 9.5, 7.0, 8.0])print(hours_floatlist,type(hours_floatlist))# tf.train.Int64Listage_int64list = tf.train.Int64List(value = [42])print(age_int64list,type(age_int64list))

二,构建一个tf.train.Features对象

# tf.train.Featuresfeatures = tf.train.Features(    feature = {
"favorite_books": tf.train.Feature(bytes_list = favorite_books_bytelist), "hours": tf.train.Feature(float_list = hours_floatlist), "age": tf.train.Feature(int64_list = age_int64list), }) print(features)

三,构建一个tf.train.Example对象

# tf.train.Example(使用features构建Example)example = tf.train.Example(features=features)print(example)# 将tf.train.Example对象序列化.serialized_example = example.SerializeToString()print(serialized_example)

四,将tf.train.Example对象存入文件中,生成的tf_record文件

output_dir = 'tfrecord_basic'if not os.path.exists(output_dir):    os.mkdir(output_dir)filename = "test.tfrecords"filename_fullpath = os.path.join(output_dir, filename)# 将序列化后的tf.train.Example对象存入文件中.with tf.io.TFRecordWriter(filename_fullpath) as writer:    for i in range(3):# 写入三次        writer.write(serialized_example)

五,使用tf.data的API读取tf_record文件

dataset = tf.data.TFRecordDataset([filename_fullpath])for serialized_example_tensor in dataset:    print(serialized_example_tensor)

六,将tf.train.Example对象反序列化

# 定义每个feature的数据类型expected_features = {
"favorite_books": tf.io.VarLenFeature(dtype = tf.string), # VarLenFeature代表变长的数据类型 "hours": tf.io.VarLenFeature(dtype = tf.float32), "age": tf.io.FixedLenFeature([], dtype = tf.int64), # FixedLenFeature代表定长的数据类型}# 解析TFRecord数据集dataset = tf.data.TFRecordDataset([filename_fullpath])for serialized_example_tensor in dataset: example = tf.io.parse_single_example( serialized_example_tensor, expected_features) # print(example) books = tf.sparse.to_dense(example["favorite_books"],default_value=b"") for book in books: print(book.numpy().decode("UTF-8"))

七,将tf.train.Example对象存入压缩文件中,生成的tf_record压缩文件

filename_fullpath_zip = filename_fullpath + '.zip'options = tf.io.TFRecordOptions(compression_type = "GZIP")with tf.io.TFRecordWriter(filename_fullpath_zip, options) as writer:    for i in range(3):        writer.write(serialized_example)

八,使用tf.data的API读取tf_record压缩文件

dataset_zip = tf.data.TFRecordDataset([filename_fullpath_zip],compression_type= "GZIP")for serialized_example_tensor in dataset_zip:    example = tf.io.parse_single_example(        serialized_example_tensor,        expected_features)    books = tf.sparse.to_dense(example["favorite_books"],default_value=b"")    for book in books:        print(book.numpy().decode("UTF-8"))

转载地址:http://nvili.baihongyu.com/

你可能感兴趣的文章
小米启动安心服务月 手机家电产品可免费清洁保养
查看>>
刘作虎:一加新品将全系支持 5G
查看>>
滴滴顺风车上线新功能,特殊时期便捷出行
查看>>
不会延期!iPhone 12S预计如期在9月发售:升级三星LTPO屏幕
查看>>
腾讯物联网操作系统TencentOS tiny线上移植大赛,王者机器人、QQ公仔、定制开发板等礼品等你来拿 !
查看>>
为云而生,腾讯云服务器操作系统TencentOS内核正式开源
查看>>
腾讯汤道生:开源已成为许多技术驱动型产业重要的创新推动力
查看>>
微信小程序多端框架 kbone 开源
查看>>
视频质量评估算法 DVQA 正式开源
查看>>
腾讯优图开源视频动作检测算法DBG,打破两项世界纪录
查看>>
在中国提供了60亿次服务的疫情模块向世界开源 腾讯抗疫科技输出海外
查看>>
在中国提供了60亿次服务的疫情模块向世界开源
查看>>
世界卫生组织与腾讯加深合作 新冠肺炎AI自查助手全球开源
查看>>
Hibernate 中get, load 区别
查看>>
java反射详解
查看>>
JPA 注解
查看>>
JQuery 简介
查看>>
Java创建对象的方法
查看>>
Extjs自定义组件
查看>>
TreeGrid 异步加载节点
查看>>