주의는 종종 룩업 벡터 를 통해 벡터 세트를 하나의 벡터로 집계하는 방법입니다 . 일반적으로 는 모델에 대한 입력 또는 이전 시간 단계의 숨겨진 상태이거나 숨겨진 상태를 한 수준 아래로 쌓아 올린 것입니다 (적층 LSTM의 경우).viuvi
결과는 현재 시간 단계와 관련된 컨텍스트를 포함하므로 종종 컨텍스트 벡터 라고합니다 .c
이 추가 컨텍스트 벡터 는 RNN / LSTM에도 공급됩니다 (원래 입력과 간단히 연결될 수 있음). 따라서 컨텍스트를 사용하여 예측을 도울 수 있습니다.c
가장 간단한 방법은 확률 벡터 및 여기서 는 모든 이전 의 연결입니다 . 공통 조회 벡터 는 현재 숨겨진 상태 입니다.p=softmax(VTu)c=∑ipiviVviuht
이것에는 많은 변형이 있으며 원하는대로 복잡하게 만들 수 있습니다. 예를 들어, 를 로짓으로 사용하는 대신 대신 선택할 수 있습니다 . 여기서 는 임의의 신경망입니다.vTiuf(vi,u)f
시퀀스-시퀀스 모델에 대한 일반적인주의 메커니즘은 . 여기서 는 인코더의 숨겨진 상태이고 는 현재 숨겨진 상태입니다. 디코더의 상태. 와 두 모두 매개 변수입니다.p=softmax(qTtanh(W1vi+W2ht))vhtqW
주의 아이디어에 다른 변형을 보여주는 일부 논문 :
포인터 네트워크 는 조합 최적화 문제를 해결하기 위해 참조 입력에주의를 기울입니다.
반복 엔티티 네트워크 는 텍스트를 읽는 동안 다른 엔티티 (사람 / 객체)에 대해 별도의 메모리 상태를 유지하고주의를 기울여 올바른 메모리 상태를 업데이트합니다.
변압기 모델은 또한 광범위한 관심을 기울입니다. 주의 집중 공식은 약간 더 일반적이며 키 벡터 도 포함 합니다.주의 가중치 는 실제로 키와 조회 사이에서 계산되며 컨텍스트는 로 구성됩니다 .kipvi
간단한 테스트를 통과했다는 사실을 넘어서는 정확성을 보장 할 수는 없지만 한 가지 형태의주의를 빠르게 구현합니다.
기본 RNN :
def rnn(inputs_split):
bias = tf.get_variable('bias', shape = [hidden_dim, 1])
weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])
hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
for i, input in enumerate(inputs_split):
input = tf.reshape(input, (batch, in_dim, 1))
last_state = hidden_states[-1]
hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
hidden_states.append(hidden)
return hidden_states[-1]
새로운 숨겨진 상태가 계산되기 전에 몇 줄만 추가하면됩니다.
if len(hidden_states) > 1:
logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
probs = tf.nn.softmax(logits)
probs = tf.reshape(probs, (batch, -1, 1, 1))
context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
else:
context = tf.zeros_like(last_state)
last_state = tf.concat([last_state, context], axis = 1)
hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
전체 코드