在 TensorFlow 中,处理可变长度张量(Variable Length Tensors)通常需要结合 Ragged Tensors 或 Sparse Tensors 使用,因为标准的 Dense Tensor(密集张量)要求所有维度的大小都必须是固定的。

Ragged Tensors

Ragged Tensors 可以用来表示具有不同长度的嵌套序列。以下是一个示例:

import tensorflow as tf

# 创建一个 Ragged Tensor
ragged_tensor = tf.ragged.constant([[1, 2, 3], [4, 5], [6]])

# 打印 Ragged Tensor
print(ragged_tensor)

修改 Ragged Tensor

可以使用 tf.ragged.concat 和其他 Ragged Tensor 特定的操作进行修改。

python复制代码# 添加新子列表
new_row = tf.ragged.constant([[7, 8]])
ragged_tensor = tf.concat([ragged_tensor, new_row], axis=0)

# 修改特定元素
rt = tf.RaggedTensor.from_row_lengths(values=[1, 2, 3, 4, 5, 6, 7, 8], row_lengths=[3, 2, 1, 2])
modified_rt = rt.with_flat_values(tf.tensor_scatter_nd_update(rt.flat_values, [[3]], [10]))

print(modified_rt)

Sparse Tensors

Sparse Tensors 是另一种处理可变长度数据的方式,特别是当数据中有大量的零时。以下是一个示例:

import tensorflow as tf

# 创建一个 Sparse Tensor
indices = [[0, 0], [1, 2], [2, 3]]
values = [1, 2, 3]
dense_shape = [3, 4]
sparse_tensor = tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)

# 转换为稠密张量以查看内容
dense_tensor = tf.sparse.to_dense(sparse_tensor)

# 打印 Sparse Tensor
print(dense_tensor)

在这个示例中,sparse_tensor 只存储了三个非零值及其对应的位置。转换为密集张量后的输出如下:

[[1 0 0 0]
 [0 0 2 0]
 [0 0 0 3]]

修改 Sparse Tensor

可以使用 tf.sparse.concat 和 tf.sparse.reorder 进行修改。

python复制代码# 添加新元素
new_indices = tf.constant([[0, 3], [2, 1]], dtype=tf.int64)
new_values = tf.constant([4, 5], dtype=tf.int32)
sparse_tensor = tf.sparse.concat(axis=0, sp_inputs=[sparse_tensor, tf.SparseTensor(new_indices, new_values, dense_shape)])

# 删除特定元素
mask = sparse_tensor.values != 2
sparse_tensor = tf.sparse.retain(sparse_tensor, mask)

print(tf.sparse.to_dense(sparse_tensor))

Ragged Tensors 和 Sparse Tensors 的区别

  • Ragged Tensors: 适用于处理嵌套序列或不规则形状的数据结构。它们允许不同的行有不同数量的元素。
  • Sparse Tensors: 适用于大部分元素为零的张量,节省存储空间和计算资源。

Dense Tensor(密集张量)

Dense Tensor 是 TensorFlow 中的标准张量类型,要求所有维度的大小都必须是固定的,并且所有元素都有存储,无论是否为零。这对于大多数传统机器学习和深度学习任务是最常用的。

示例

python复制代码import tensorflow as tf

# 创建一个 Dense Tensor
dense_tensor = tf.constant([[1, 2, 3], [4, 5, 6]])

# 打印 Dense Tensor
print(dense_tensor)

在这个示例中,dense_tensor 是一个普通的二维张量,打印输出如下:

复制代码

修改 Dense Tensor

可以使用 TensorFlow 的内置操作来修改张量,例如 tf.concat、tf.slice 等。

new_row = tf.constant([[7, 8, 9]])
dense_tensor = tf.concat([dense_tensor, new_row], axis=0)

# 修改特定元素
dense_tensor = tf.tensor_scatter_nd_update(dense_tensor, [[0, 1]], [10])

print(dense_tensor)
本站无任何商业行为
个人在线分享 » TensorFlow 张量
E-->