<div id="content_views" class="markdown_views"> <!-- flowchart arrow icon, do not delete --> <svg xmlns="http://www.w3.org/2000/svg" style="display: none;"><path stroke-linecap="round" d="M5,0 0,2.5 5,5z" id="raphael-marker-block" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path></svg> <h1 id="tensorflow經常使用函數筆記"><a name="t0"></a>Notes on Common TensorFlow Functions</h1>
<hr>
<h2 id="tfconcat"><a name="t1"></a><a href="https://www.tensorflow.org/versions/r1.1/api_docs/python/tf/concat" rel="nofollow" target="_blank">tf.concat</a></h2>
<p>Concatenates a list of tensors along one dimension, much like numpy.concatenate. Example from the official documentation:</p>
<pre class="prettyprint" name="code"><code class="hljs python">import tensorflow as tf

# Each "==>" shows the value the tensor evaluates to (e.g. via sess.run).
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0)  # ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1)  # ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]

# tensors t3 and t4, both with shape [2, 3]
t3 = tf.zeros([2, 3])
t4 = tf.ones([2, 3])
tf.shape(tf.concat([t3, t4], 0))  # ==> [4, 3]
tf.shape(tf.concat([t3, t4], 1))  # ==> [2, 6]</code></pre>
<p>Plain Python lists also work as inputs (as in the example above), as long as they can be converted to tensors; tf.concat still returns a Tensor.</p>
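<p>To make the axis argument concrete, here is a minimal pure-Python sketch that mimics tf.concat on nested lists. It is an illustration only, not TensorFlow's implementation, and the helper name concat_lists is hypothetical. Concatenating along axis 0 chains the inputs end to end; along a deeper axis the matching sub-lists are joined, which is why all dimensions other than axis must agree.</p>

```python
def concat_lists(values, axis):
    """Mimic tf.concat on nested Python lists (hypothetical helper).

    Only non-negative axes are supported, and shapes are checked just along
    the dimensions the recursion walks through.
    """
    if axis == 0:
        # Joining along the outermost dimension: chain the inputs end to end.
        result = []
        for v in values:
            result.extend(v)
        return result
    # For a deeper axis, all inputs must agree on this outer dimension...
    if len({len(v) for v in values}) != 1:
        raise ValueError("all dimensions except `axis` must match")
    # ...and the i-th sub-lists of every input are concatenated along axis - 1.
    return [concat_lists([v[i] for v in values], axis - 1)
            for i in range(len(values[0]))]


t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
print(concat_lists([t1, t2], 0))  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
print(concat_lists([t1, t2], 1))  # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
```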
<h2 id="tfgather"><a name="t2"></a><a href="https://www.tensorflow.org/versions/r1.1/api_docs/python/tf/gather" rel="nofollow" target="_blank">tf.gather</a></h2>
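<p>Works like array indexing: it picks out the entries at the given indices and builds a new tensor from them, which is handy when the indices to extract are not contiguous. Note that it indexes along a single axis only (axis 0 by default; newer TensorFlow versions also accept an axis argument), so each index selects a whole slice, e.g. an entire row of a matrix, as in the example below. To index several dimensions at once, use tf.gather_nd.</p>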
<pre class="prettyprint" name="code"><code class="hljs python">import tensorflow as tf

a = tf.Variable([[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]])
index_a = tf.Variable([0,2])

b = tf.Variable([1,2,3,4,5,6,7,8,9,10])
index_b = tf.Variable([2,4,6,8])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather(a, index_a)))  # rows 0 and 2 of a
    print(sess.run(tf.gather(b, index_b)))  # elements 2, 4, 6, 8 of b

# [[ 1  2  3  4  5]
#  [11 12 13 14 15]]
# [3 5 7 9]</code></pre>
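<p>The one-slice-per-index behaviour can be sketched in pure Python as follows. The helper gather is hypothetical and only illustrates the semantics; its axis parameter mirrors the one newer TensorFlow versions accept, and indices must be given as (possibly nested) lists.</p>

```python
def gather(params, indices, axis=0):
    """Mimic tf.gather on nested lists (hypothetical helper).

    The output shape is params.shape[:axis] + indices.shape + params.shape[axis + 1:].
    """
    if axis > 0:
        # Apply the same selection inside every slice of the outer dimension.
        return [gather(p, indices, axis - 1) for p in params]
    result = []
    for i in indices:
        if isinstance(i, list):
            # Nested index lists keep their own shape in the output.
            result.append(gather(params, i))
        elif 0 <= i < len(params):
            # Each index selects a whole slice, e.g. a full row of a matrix.
            result.append(params[i])
        else:
            # Reject out-of-range indices instead of letting Python wrap negatives.
            raise IndexError(f"index {i} is out of range [0, {len(params)})")
    return result


a = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]]
b = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print(gather(a, [0, 2]))          # [[1, 2, 3, 4, 5], [11, 12, 13, 14, 15]]
print(gather(b, [2, 4, 6, 8]))    # [3, 5, 7, 9]
print(gather(a, [0, 2], axis=1))  # [[1, 3], [6, 8], [11, 13]]
```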
<h2 id="tfgathernd"><a name="t3"></a><a href="https://www.tensorflow.org/versions/r1.1/api_docs/python/tf/gather_nd" rel="nofollow" target="_blank">tf.gather_nd</a></h2>
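<p>Same as tf.gather, but each index can address several dimensions at once. The example below shows only a very simple usage; see the official documentation for more complex ones.</p>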
<pre class="prettyprint" name="code"><code class="hljs python">import tensorflow as tf

a = tf.Variable([[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]])
# Each inner pair is a (row, column) coordinate into a.
index_a = tf.Variable([[0,2], [0,4], [2,2]])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather_nd(a, index_a)))

# [ 3  5 13]</code></pre>
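<p>One of the less obvious cases is a coordinate shorter than the rank of the tensor, which selects a whole slice instead of a single element. The pure-Python sketch below (hypothetical helper gather_nd, not TensorFlow's implementation, and without bounds checking) shows both cases on the same matrix a.</p>

```python
def gather_nd(params, indices):
    """Mimic tf.gather_nd on nested lists (hypothetical helper, no bounds checks).

    The innermost lists of `indices` are coordinates into the leading dimensions
    of `params`; a coordinate shorter than the rank of `params` selects a slice.
    """
    if indices and isinstance(indices[0], list):
        # A batch of coordinates: the output keeps the outer shape of `indices`.
        return [gather_nd(params, idx) for idx in indices]
    # A single coordinate: step down one dimension per component.
    result = params
    for i in indices:
        result = result[i]
    return result


a = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]]
print(gather_nd(a, [[0, 2], [0, 4], [2, 2]]))  # [3, 5, 13]  (single elements)
print(gather_nd(a, [[1], [0]]))                # [[6, 7, 8, 9, 10], [1, 2, 3, 4, 5]]  (whole rows)
print(gather_nd(a, [2, 2]))                    # 13  (one coordinate, scalar result)
```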
<h2 id="tfgreater"><a name="t4"></a><a href="https://www.tensorflow.org/versions/r1.1/api_docs/python/tf/greater" rel="nofollow" target="_blank">tf.greater</a></h2>
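<p>Element-wise comparison. x and y should have the same shape (tf.greater also supports NumPy-style broadcasting, so broadcast-compatible shapes work too), and tf.greater(x, y) returns a bool tensor of that shape: a position is True if the element of x is greater than the corresponding element of y, and False otherwise. A related function is <a href="https://www.tensorflow.org/versions/r1.1/api_docs/python/tf/greater_equal" rel="nofollow" target="_blank">tf.greater_equal</a>.</p>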
<pre class="prettyprint" name="code"><code class="hljs python">import tensorflow as tf

x = tf.Variable([[1,2,3], [6,7,8], [11,12,13]])
y = tf.Variable([[0,1,2], [5,6,7], [10,11,12]])

x1 = tf.Variable([[1,2,3], [6,7,8], [11,12,13]])
y1 = tf.Variable([[10,1,2], [15,6,7], [10,21,12]])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.greater(x, y)))
    print(sess.run(tf.greater(x1, y1)))

# [[ True  True  True]
#  [ True  True  True]
#  [ True  True  True]]
# [[False  True  True]
#  [False  True  True]
#  [ True False  True]]</code></pre>
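<p>tf.greater follows NumPy-style broadcasting, so the two inputs only need compatible shapes; comparing a matrix with a scalar, for instance, tests every element against that scalar. A pure-Python sketch of this rule on nested lists follows; the helpers rank and greater are hypothetical and assume rectangular (non-ragged) input.</p>

```python
def rank(x):
    """Number of nested list levels (0 for a scalar); assumes rectangular nesting."""
    r = 0
    while isinstance(x, list):
        r += 1
        x = x[0] if x else None
    return r


def greater(x, y):
    """Element-wise x > y with NumPy-style broadcasting (hypothetical helper)."""
    rx, ry = rank(x), rank(y)
    if rx == 0 and ry == 0:
        return x > y
    # Shapes are aligned from the right: the operand with fewer dimensions
    # is reused against every slice of the other one.
    if rx < ry:
        return [greater(x, yi) for yi in y]
    if ry < rx:
        return [greater(xi, y) for xi in x]
    # Same rank: outer sizes must match, or one of them must be 1 (stretched).
    if len(x) == len(y):
        return [greater(xi, yi) for xi, yi in zip(x, y)]
    if len(x) == 1:
        return [greater(x[0], yi) for yi in y]
    if len(y) == 1:
        return [greater(xi, y[0]) for xi in x]
    raise ValueError(f"incompatible sizes {len(x)} and {len(y)}")


x1 = [[1, 2, 3], [6, 7, 8], [11, 12, 13]]
y1 = [[10, 1, 2], [15, 6, 7], [10, 21, 12]]
print(greater(x1, y1))  # [[False, True, True], [False, True, True], [True, False, True]]
print(greater(x1, 7))   # [[False, False, False], [False, False, True], [True, True, True]]
```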
<h2 id="tfcast"><a name="t5"></a><a href="https://www.tensorflow.org/versions/r1.1/api_docs/python/tf/cast" rel="nofollow" target="_blank">tf.cast</a></h2>
<p>Converts a tensor to a different data type.</p>
<pre class="prettyprint" name="code"><code class="hljs python">import tensorflow as tf

# Graph mode (TF 1.x): print shows the symbolic tensor, not its values.
# The exact op names depend on what is already in the graph.
a = tf.constant([0, 2, 0, 4, 2, 2], dtype='int32')
print(a)
# &lt;tf.Tensor 'Const_1:0' shape=(6,) dtype=int32&gt;

b = tf.cast(a, 'float32')
print(b)
# &lt;tf.Tensor 'Cast:0' shape=(6,) dtype=float32&gt;</code></pre>
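<p>A frequent reason to call tf.cast is to turn the bool output of a comparison such as tf.greater into numbers that can be summed or averaged. The pure-Python sketch below (hypothetical helper cast) is only a rough analogue: it handles just 'float32' and 'int32', and it assumes that float-to-int casts drop the fractional part, so that 1.8 becomes 1.</p>

```python
import math


def cast(values, dtype):
    """Rough pure-Python analogue of tf.cast on a flat list (hypothetical helper).

    Assumptions: only 'float32' and 'int32' are handled, and Python's 64-bit
    floats stand in for float32, so float32 rounding is not reproduced.
    """
    if dtype == 'float32':
        # int/bool -> float: True becomes 1.0 and False becomes 0.0.
        return [float(v) for v in values]
    if dtype == 'int32':
        # float -> int keeps only the integer part (truncation toward zero).
        return [math.trunc(v) for v in values]
    raise ValueError(f"unsupported dtype: {dtype}")


print(cast([0, 2, 0, 4, 2, 2], 'float32'))  # [0.0, 2.0, 0.0, 4.0, 2.0, 2.0]
print(cast([1.8, 2.2], 'int32'))            # [1, 2]

# A common pattern: turn a bool mask (e.g. from a comparison) into 0/1 values
# so that it can be averaged.
mask = [True, False, True, True]
print(sum(cast(mask, 'float32')) / len(mask))  # 0.75
```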
<h2 id="tfexpanddims-tfsqueeze"><a name="t6"></a><a href="https://www.tensorflow.org/api_docs/python/tf/expand_dims" rel="nofollow" target="_blank">tf.expand_dims</a> & <a href="https://www.tensorflow.org/api_docs/python/tf/squeeze" rel="nofollow" target="_blank">tf.squeeze</a></h2>
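<p>Add / remove dimensions of a tensor: tf.expand_dims inserts a dimension of size 1 at the given axis, and tf.squeeze removes dimensions of size 1.</p>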
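<p>Neither function changes the data, only the shape, so their effect can be written as shape arithmetic. The pure-Python sketch below uses hypothetical helpers expand_dims_shape and squeeze_shape that work on shape tuples; negative axes count from the end, with expand_dims accepting axis values in [-(rank + 1), rank].</p>

```python
def expand_dims_shape(shape, axis):
    """Shape after tf.expand_dims: insert a size-1 dimension at `axis` (hypothetical helper)."""
    rank = len(shape)
    # Negative axes count from the end; the valid range is [-(rank + 1), rank].
    if not -(rank + 1) <= axis <= rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    if axis < 0:
        axis += rank + 1
    return shape[:axis] + (1,) + shape[axis:]


def squeeze_shape(shape, axis=None):
    """Shape after tf.squeeze: drop all size-1 dimensions, or only those in `axis`."""
    rank = len(shape)
    if axis is None:
        return tuple(d for d in shape if d != 1)
    axes = set()
    for a in ([axis] if isinstance(axis, int) else axis):
        if not -rank <= a < rank:
            raise ValueError(f"axis {a} out of range for rank {rank}")
        a %= rank  # normalize negative axes
        if shape[a] != 1:
            raise ValueError(f"cannot squeeze dimension {a} of size {shape[a]}")
        axes.add(a)
    return tuple(d for i, d in enumerate(shape) if i not in axes)


print(expand_dims_shape((6,), 0))      # (1, 6)
print(squeeze_shape((1, 6), 0))        # (6,)
print(squeeze_shape((1, 2, 1, 3, 1)))  # (2, 3)
```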
<pre class="prettyprint" name="code"><code class="hljs python">import tensorflow as tf

a = tf.constant([0, 2, 0, 4, 2, 2], dtype='int32')
print(a)
# &lt;tf.Tensor 'Const_1:0' shape=(6,) dtype=int32&gt;

b = tf.expand_dims(a, 0)  # insert a new axis 0: (6,) -> (1, 6)
print(b)
# &lt;tf.Tensor 'ExpandDims:0' shape=(1, 6) dtype=int32&gt;

print(tf.squeeze(b, 0))   # remove axis 0 again: (1, 6) -> (6,)
# &lt;tf.Tensor 'Squeeze:0' shape=(6,) dtype=int32&gt;</code></pre> </div>