本文共 3736 字,大约阅读时间需要 12 分钟。
本文将介绍:如何用 tf.train.Example 构建并序列化样本、将其写入 TFRecord 文件(含 GZIP 压缩),以及如何用 tf.data.TFRecordDataset 读取并解析这些文件。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import tensorflow as tf
from tensorflow import keras

# Print the versions of the Python libraries in use.
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

# TFRecord is a file format.
# -> a tf.train.Example can be a single sample or a group of samples.
# Each Example -> tf.train.Features -> {"key": tf.train.Feature}
# Each Feature has its own format:
#   tf.train.Feature -> tf.train.BytesList / FloatList / Int64List

# Step 1: build one BytesList, one FloatList and one Int64List object.

# tf.train.BytesList: the book titles are encoded to UTF-8 bytes first.
favorite_books = [
    title.encode('utf-8') for title in ("machine learning", "cc150")
]
favorite_books_bytelist = tf.train.BytesList(value=favorite_books)
print(favorite_books_bytelist, type(favorite_books_bytelist))

# tf.train.FloatList
hours_floatlist = tf.train.FloatList(value=[15.5, 9.5, 7.0, 8.0])
print(hours_floatlist, type(hours_floatlist))

# tf.train.Int64List
age_int64list = tf.train.Int64List(value=[42])
print(age_int64list, type(age_int64list))
# tf.train.Features: wrap each value list in a tf.train.Feature and
# collect them under their string keys.
feature_map = {
    "favorite_books": tf.train.Feature(bytes_list=favorite_books_bytelist),
    "hours": tf.train.Feature(float_list=hours_floatlist),
    "age": tf.train.Feature(int64_list=age_int64list),
}
features = tf.train.Features(feature=feature_map)
print(features)
# tf.train.Example: build the Example from the features assembled above.
example = tf.train.Example(features=features)
print(example)

# Serialize the tf.train.Example object; the result is what gets written
# to the TFRecord files.
serialized_example = example.SerializeToString()
print(serialized_example)
# Directory and file name of the uncompressed TFRecord output.
output_dir = 'tfrecord_basic'
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it does not
# fail when the directory already exists and has no check-then-create race.
os.makedirs(output_dir, exist_ok=True)
filename = "test.tfrecords"
filename_fullpath = os.path.join(output_dir, filename)

# Store the serialized tf.train.Example object in the file.
with tf.io.TFRecordWriter(filename_fullpath) as writer:
    for _ in range(3):  # write the same record three times
        writer.write(serialized_example)
# Read the file back as a dataset and print every record as stored,
# i.e. still in serialized form (no parsing yet).
dataset = tf.data.TFRecordDataset(filenames=[filename_fullpath])
for raw_record in dataset:
    print(raw_record)
# Declare the data type of every feature so the records can be parsed.
expected_features = {
    # VarLenFeature denotes a variable-length feature.
    "favorite_books": tf.io.VarLenFeature(dtype=tf.string),
    "hours": tf.io.VarLenFeature(dtype=tf.float32),
    # FixedLenFeature denotes a fixed-length feature.
    "age": tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
}

# Parse the TFRecord dataset record by record.
dataset = tf.data.TFRecordDataset(filenames=[filename_fullpath])
for raw_record in dataset:
    example = tf.io.parse_single_example(
        serialized=raw_record,
        features=expected_features,
    )
    # "favorite_books" is parsed as a sparse tensor; densify it before
    # reading the individual titles.
    books = tf.sparse.to_dense(example["favorite_books"], default_value=b"")
    for raw_title in books.numpy():
        print(raw_title.decode("UTF-8"))
# Write the same records again, this time GZIP-compressed.
# NOTE(review): the '.zip' suffix is only a name; the compression actually
# requested below is GZIP.
filename_fullpath_zip = filename_fullpath + '.zip'
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter(filename_fullpath_zip, options=options) as writer:
    for _ in range(3):
        writer.write(serialized_example)
# Read the compressed file; the reader is given the same compression type
# that was used for writing.
dataset_zip = tf.data.TFRecordDataset(
    filenames=[filename_fullpath_zip],
    compression_type="GZIP",
)
for raw_record in dataset_zip:
    example = tf.io.parse_single_example(
        serialized=raw_record,
        features=expected_features,
    )
    # Densify the sparse "favorite_books" feature, then print each title.
    books = tf.sparse.to_dense(example["favorite_books"], default_value=b"")
    for raw_title in books.numpy():
        print(raw_title.decode("UTF-8"))
转载地址:http://nvili.baihongyu.com/