TensorFlow에서 Variable과 get_variable의 차이점


125

내가 아는 Variable한 변수를 만들기위한 기본 작업 get_variable이며 주로 가중치 공유에 사용됩니다.

한편 으로 변수가 필요할 때마다 get_variable원시 Variable연산 대신 사용 을 제안하는 사람들이 있습니다 . 반면에 저는 get_variableTensorFlow의 공식 문서 및 데모에서의 사용을 볼뿐입니다 .

따라서이 두 가지 메커니즘을 올바르게 사용하는 방법에 대한 몇 가지 경험 규칙을 알고 싶습니다. "표준"원칙이 있습니까?


6
get_variable은 새로운 방식이고, Variable은 이전 방식입니다 (영원히 지원 될 수 있음). Lukasz가 말한 것처럼 (PS : 그는 TF에서 많은 변수 이름 범위를 썼습니다)
Yaroslav Bulatov

답변:


90

항상 사용하는 것이 좋습니다 tf.get_variable(...). 예를 들어 다중 GPU 설정 (다중 GPU CIFAR 예제 참조)과 같이 언제든지 변수를 공유해야하는 경우 코드를 리팩토링하는 것이 더 쉬워집니다. 단점이 없습니다.

Pure tf.Variable는 하위 수준입니다. 어떤 시점 tf.get_variable()에는 존재하지 않았으므로 일부 코드는 여전히 저수준 방식을 사용합니다.


5
답변 해주셔서 감사합니다. 그러나 나는 어디에서나 대체하는 방법에 대해 여전히 한 가지 질문 tf.Variabletf.get_variable있습니다. 즉, numpy 배열로 변수를 초기화하고 싶을 때 에서처럼 깨끗하고 효율적인 방법을 찾을 수 없습니다 tf.Variable. 어떻게 해결합니까? 감사.
Lifu Huang

69

tf.Variable는 클래스이며, tf.Variable을 포함하여 작성하는 방법에는 여러 가지가있다 tf.Variable.__init__tf.get_variable.

tf.Variable.__init__: initial_value를 사용 하여 새 변수를 만듭니다 .

W = tf.Variable(<initial-value>, name=<optional-name>)

tf.get_variable: 이러한 매개 변수로 기존 변수를 가져 오거나 새 변수를 만듭니다. 이니셜 라이저를 사용할 수도 있습니다.

W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
       regularizer=None, trainable=True, collections=None)

다음과 같은 이니셜 라이저를 사용하는 것이 매우 유용합니다 xavier_initializer.

W = tf.get_variable("W", shape=[784, 256],
       initializer=tf.contrib.layers.xavier_initializer())

여기에 더 많은 정보가 있습니다 .


예, Variable실제로 __init__. get_variable매우 편리 하기 때문에 내가 본 대부분의 TensorFlow 코드 Variableget_variable. 둘 중에서 선택할 때 고려해야 할 규칙이나 요소가 있습니까? 감사합니다!
Lifu Huang

특정 값을 원할 경우 Variable을 사용하는 것은 간단합니다. x = tf.Variable (3).
성 김

@SungKim은 일반적으로 사용할 때 tf.Variable()잘린 정규 분포에서 임의의 값으로 초기화 할 수 있습니다. 여기 내 예가 w1 = tf.Variable(tf.truncated_normal([5, 50], stddev = 0.01), name = 'w1')있습니다. 이것에 상응하는 것은 무엇입니까? 잘린 법선을 원한다고 어떻게 말합니까? 그냥해야 w1 = tf.get_variable(name = 'w1', shape = [5,50], initializer = tf.truncated_normal, regularizer = tf.nn.l2_loss)하나요?
Euler_Salter

@Euler_Salter : tf.truncated_normal_initializer()원하는 결과를 얻기 위해 사용할 수 있습니다 .
베타

46

하나와 다른 두 가지 주요 차이점을 찾을 수 있습니다.

  1. 첫 번째는 tf.Variable항상 새 변수를 생성하는 반면 그래프에서 지정된 매개 변수 tf.get_variable가있는 기존 변수를 가져 와서 존재하지 않는 경우 새 변수를 생성 한다는 것 입니다.

  2. tf.Variable 초기 값을 지정해야합니다.

tf.get_variable재사용 검사를 수행 하기 위해 함수 가 현재 변수 범위를 이름 앞에 붙인다는 점을 명확히하는 것이 중요 합니다. 예를 들면 :

with tf.variable_scope("one"):
    a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
    c = tf.get_variable("v", [1]) #c.name == "one/v:0"

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1]) #d.name == "two/v:0"
    e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"

assert(a is c)  #Assertion is true, they refer to the same object.
assert(a is d)  #AssertionError: they are different objects
assert(d is e)  #AssertionError: they are different objects

마지막 주장 오류는 흥미 롭습니다. 같은 범위에서 이름이 같은 두 변수는 같은 변수로 간주됩니다. 당신은 변수의 이름을 테스트한다면 d그리고 e당신은 알게 될 것이다 Tensorflow은 변수의 이름을 변경하는 것이 e:

d.name   #d.name == "two/v:0"
e.name   #e.name == "two/v_1:0"

좋은 예! 과 관련 d.name하여 e.name저는 텐서 그래프 이름 지정 작업에 대한 이 TensorFlow 문서를If the default graph already contained an operation named "answer", the TensorFlow would append "_1", "_2", and so on to the name, in order to make it unique.
보았습니다

2

또 다른 차이점은 하나는 ('variable_store',)컬렉션에 있지만 다른 하나는 그렇지 않다는 것입니다.

소스 코드를 참조하십시오 :

def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store

설명하겠습니다.

import tensorflow as tf
from tensorflow.python.framework import ops

embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32) 
embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024])

graph = tf.get_default_graph()
collections = graph.collections

for c in collections:
    stores = ops.get_collection(c)
    print('collection %s: ' % str(c))
    for k, store in enumerate(stores):
        try:
            print('\t%d: %s' % (k, str(store._vars)))
        except:
            print('\t%d: %s' % (k, str(store)))
    print('')

출력 :

collection ('__variable_store',): 0: {'word_embeddings_2': <tf.Variable 'word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}

당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.