Merge and split

  • tf.concat
  • tf.split
  • tf.stack
  • tf.unstack

concat

  • Statistics about scores
    • [class1-4,students,scores]
    • [class5-6,students,scores]
import tensorflow as tf
# Student scores for 6 classes
a = tf.ones([4, 35, 8])
b = tf.ones([2, 35, 8])
c = tf.concat([a, b], axis=0)
c.shape
TensorShape([6, 35, 8])
# 3 students in each class take a make-up exam
a = tf.ones([4, 32, 8])
b = tf.ones([4, 3, 8])
tf.concat([a, b], axis=1).shape
TensorShape([4, 35, 8])
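
The shape rule behind these two examples can be written out in plain Python. This is only a sketch of the rule, not how TensorFlow implements tf.concat; the helper name concat_shape is hypothetical and works on shape lists rather than tensors.

```python
def concat_shape(shapes, axis):
    """Return the shape that results from concatenating the given shapes."""
    rank = len(shapes[0])
    if axis < 0:
        axis += rank  # allow negative axes, e.g. -1 for the last dimension
    out = list(shapes[0])
    for s in shapes[1:]:
        # every dimension except `axis` must agree
        if len(s) != rank or any(s[i] != out[i] for i in range(rank) if i != axis):
            raise ValueError(f"cannot concat {shapes[0]} with {s} along axis {axis}")
        out[axis] += s[axis]  # sizes add up along the concat axis
    return out


print(concat_shape([[4, 35, 8], [2, 35, 8]], axis=0))  # [6, 35, 8]
print(concat_shape([[4, 32, 8], [4, 3, 8]], axis=1))   # [4, 35, 8]
```

The rank of the result is the same as the rank of the inputs; only the size along the chosen axis grows.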

Along distinct dim/axis

[Figure: difference between concatenating along different axes]

stack: create new dim

  • Statistics about scores
    • School1:[classes,students,scores]
    • School2:[classes,students,scores]
    • [schools,classes,students,scores]
a = tf.ones([4, 35, 8])
b = tf.ones([4, 35, 8])
a.shape
TensorShape([4, 35, 8])
b.shape
TensorShape([4, 35, 8])
tf.concat([a, b], axis=-1).shape
TensorShape([4, 35, 16])
tf.stack([a, b], axis=0).shape
TensorShape([2, 4, 35, 8])
tf.stack([a, b], axis=3).shape
TensorShape([4, 35, 8, 2])
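
As a rough illustration of where the new dimension ends up, the stack shape rule can be sketched in plain Python. The helper stack_shape is a hypothetical name used only here; it mirrors the shapes shown above and is not TensorFlow's implementation.

```python
def stack_shape(shapes, axis=0):
    """Return the shape that results from stacking inputs of the given shapes."""
    first = list(shapes[0])
    if any(list(s) != first for s in shapes):
        raise ValueError("all inputs to stack must have identical shapes")
    if axis < 0:
        axis += len(first) + 1  # the result has one more dimension than the inputs
    out = list(first)
    out.insert(axis, len(shapes))  # new dimension, size = number of inputs
    return out


print(stack_shape([[4, 35, 8], [4, 35, 8]], axis=0))   # [2, 4, 35, 8]
print(stack_shape([[4, 35, 8], [4, 35, 8]], axis=3))   # [4, 35, 8, 2]
print(stack_shape([[4, 35, 8], [4, 35, 8]], axis=-1))  # [4, 35, 8, 2]
```

Unlike concat, the rank increases by one, and the size of the new dimension equals the number of stacked tensors.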

Dim mismatch

a = tf.ones([4, 35, 8])
b = tf.ones([3, 33, 8])
try:
    tf.concat([a, b], axis=0).shape
except Exception as e:
    print(e)
ConcatOp : Dimensions of inputs should match: shape[0] = [4,35,8] vs. shape[1] = [3,33,8] [Op:ConcatV2] name: concat
# concat: the shapes may differ only along the concatenation axis
b = tf.ones([2, 35, 8])
c = tf.concat([a, b], axis=0)
c.shape
TensorShape([6, 35, 8])
# stack: all dimensions must be equal
try:
    tf.stack([a, b], axis=0)
except Exception as e:
    print(e)
Shapes of all inputs must match: values[0].shape = [4,35,8] != values[1].shape = [2,35,8] [Op:Pack] name: stack
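
The two compatibility rules can be compared side by side with a small plain-Python check. The names can_concat and can_stack are hypothetical helpers that operate on shape lists; this is a sketch of the rules, not a TensorFlow API.

```python
def can_stack(a, b):
    """stack requires every dimension to match."""
    return list(a) == list(b)


def can_concat(a, b, axis):
    """concat requires every dimension except `axis` to match."""
    if len(a) != len(b):
        return False
    axis = axis % len(a)  # normalize negative axes
    return all(x == y for i, (x, y) in enumerate(zip(a, b)) if i != axis)


a = [4, 35, 8]
for b in ([3, 33, 8], [2, 35, 8], [4, 35, 8]):
    print(b, can_concat(a, b, axis=0), can_stack(a, b))
# [3, 33, 8] False False
# [2, 35, 8] True False
# [4, 35, 8] True True
```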

Unstack

a.shape
TensorShape([4, 35, 8])
b = tf.ones([4, 35, 8])
c = tf.stack([a, b])
c.shape
TensorShape([2, 4, 35, 8])
aa, bb = tf.unstack(c, axis=0)
aa.shape, bb.shape
(TensorShape([4, 35, 8]), TensorShape([4, 35, 8]))
# [2,4,35,8]
res = tf.unstack(c, axis=3)
# 8 tensors of shape [2, 4, 35]
res[0].shape, res[1].shape, res[7].shape
(TensorShape([2, 4, 35]), TensorShape([2, 4, 35]), TensorShape([2, 4, 35]))
# [2,4,35,8]
res = tf.unstack(c, axis=2)
# 35 tensors of shape [2, 4, 8]
res[0].shape, res[1].shape, res[34].shape
(TensorShape([2, 4, 8]), TensorShape([2, 4, 8]), TensorShape([2, 4, 8]))
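
The pattern in the outputs above (one piece per index along the axis, with that axis removed) can be sketched in plain Python. The helper unstack_shapes is a hypothetical name and only models shapes, not tensor contents.

```python
def unstack_shapes(shape, axis=0):
    """Return the shapes of the pieces produced by unstacking along `axis`."""
    if axis < 0:
        axis += len(shape)
    # each piece keeps every dimension except the one being unstacked
    piece = [d for i, d in enumerate(shape) if i != axis]
    # there is one piece per index along `axis`
    return [list(piece) for _ in range(shape[axis])]


res = unstack_shapes([2, 4, 35, 8], axis=3)
print(len(res), res[0])  # 8 [2, 4, 35]

res = unstack_shapes([2, 4, 35, 8], axis=2)
print(len(res), res[0])  # 35 [2, 4, 8]
```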

Split

  • More flexible than unstack
# unstack: 8 tensors, one per index along axis 3
res = tf.unstack(c, axis=3)
len(res)
8
# split into 2 equal parts: 2 tensors, each of size 4 along axis 3
res = tf.split(c, axis=3, num_or_size_splits=2)
len(res)
2
res[0].shape
TensorShape([2, 4, 35, 4])
res = tf.split(c, axis=3, num_or_size_splits=[2, 2, 4])
res[0].shape, res[1].shape, res[2].shape
(TensorShape([2, 4, 35, 2]),
 TensorShape([2, 4, 35, 2]),
 TensorShape([2, 4, 35, 4]))
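
How num_or_size_splits maps to the output shapes can be sketched in plain Python. The helper split_shapes is a hypothetical name; it handles only an integer count or an explicit list of sizes and is not TensorFlow's implementation.

```python
def split_shapes(shape, num_or_size_splits, axis=0):
    """Return the shapes of the pieces produced by splitting along `axis`."""
    shape = list(shape)
    if axis < 0:
        axis += len(shape)
    dim = shape[axis]
    if isinstance(num_or_size_splits, int):
        # an integer means that many equal parts, so it must divide the dimension
        if dim % num_or_size_splits != 0:
            raise ValueError(f"{dim} is not divisible by {num_or_size_splits}")
        sizes = [dim // num_or_size_splits] * num_or_size_splits
    else:
        # a list gives the size of each part; the sizes must add up to the dimension
        sizes = list(num_or_size_splits)
        if sum(sizes) != dim:
            raise ValueError(f"sizes {sizes} do not sum to {dim}")
    # the axis is kept (unlike unstack); only its size changes per piece
    return [shape[:axis] + [s] + shape[axis + 1:] for s in sizes]


print(split_shapes([2, 4, 35, 8], 2, axis=3))
# [[2, 4, 35, 4], [2, 4, 35, 4]]
print(split_shapes([2, 4, 35, 8], [2, 2, 4], axis=3))
# [[2, 4, 35, 2], [2, 4, 35, 2], [2, 4, 35, 4]]
```

The key difference from unstack is that split keeps the rank unchanged and lets the pieces have different sizes along the split axis.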