GELU 활성화 란 무엇입니까?


18

I가 겪고있는 BERT 종이 사용 겔루 (가우스 오류 선형 단위) 와 같은 상태 방정식 이는

GELU(x)=xP(Xx)=xΦ(x).
0.5x(1+tanh[2/π(x+0.044715x3)])

방정식을 단순화하고 대략적인 방법을 설명해 주시겠습니까?

답변:


19

GELU 기능

\ mathcal {N} (0, 1) , 즉 \ Phi (x)누적 분포를N(0,1) 다음과 같이 확장 할 수 있습니다 . \ text {GELU} (x) : = x {\ Bbb P} (X \ le x) = x \ Phi (x) = 0.5x \ left (1+ \ text {erf} \ left (\ frac {x} {\ sqrt {2}} \ right) \ right)Φ(x)

GELU(x):=xP(Xx)=xΦ(x)=0.5x(1+erf(x2))

이것은 방정식 (또는 관계) 이 아니라 정의 입니다. 저자는이 제안에 대한 몇 가지 정당성을 제공했습니다 (예 : 확률 론적 비유 ). 수학적으로 이것은 정의 일뿐입니다.

GELU의 도표는 다음과 같습니다.

탄 근사

이러한 유형의 수치 근사의 경우 핵심 아이디어는 유사한 기능 (주로 경험을 기반으로 함)을 찾아 매개 변수화 한 다음 원래 함수의 포인트 세트에 맞추는 것입니다.

가 매우 가깝다는 것을 알고erf(x)tanh(x)

그리고 1 차 도함수 의 그것과 일치 에서 인 이면 (또는 더 많은 용어 포함)를 .erf(x2)tanh(2πx)x=02π

tanh(2π(x+ax2+bx3+cx4+dx5))
(xi,erf(xi2))

이 기능을 ( 이 사이트 사용 사이의 20 개 샘플에 맞추 계수는 다음과 같습니다.(1.5,1.5)

설정함으로써, , 로 추정 하였다 . 더 넓은 범위 (샘플이 20 개만 허용)의 샘플이 많을수록 계수 는 용지의 더 가깝 . 마침내 우리는a=c=d=0b0.04495641b0.044715

GELU(x)=xΦ(x)=0.5x(1+erf(x2))0.5x(1+tanh(2π(x+0.044715x3)))

대해 평균 제곱 오차 입니다 .108x[10,10]

첫 번째 파생 상품 간의 관계를 활용하지 않으면 라는 용어 가 와 같이 매개 변수에 포함되었을 것입니다. 은 덜 아름답습니다 (분석적, 수치 적)!2π

0.5x(1+tanh(0.797885x+0.035677x3))

패리티 활용

@BookYourLuck 에서 제안한대로 함수의 패리티를 사용하여 검색하는 다항식의 공간을 제한 할 수 있습니다. 그 때문에,이다 홀수 함수 즉, 인 및 또한, 다항식 함수 홀수 함수 안쪽 는 를 갖기 위해 홀수 여야합니다 ( 홀수 만 있어야erff(x)=f(x)tanhpol(x)tanhx

erf(x)tanh(pol(x))=tanh(pol(x))=tanh(pol(x))erf(x)

이전에, 우리는 심지어 권력에 대한 (거의) 제로 계수와 끝까지 행운이었다 및 , 그러나 일반적으로, 이것은 예를 들어, 같은 용어가 그 낮은 품질의 근사치로 이어질 수 이 단순히 를 선택하는 대신 추가 조건 (짝수 또는 홀수)으로 취소됩니다 .x2x40.23x20x2

S 자형 근사

와 (sigmoid) 사이에도 비슷한 관계가 있습니다. 대한 평균 제곱 오차 .erf(x)2(σ(x)12)104x[10,10]

다음은 데이터 포인트 생성, 함수 피팅 및 평균 제곱 오류 계산을위한 Python 코드입니다.

import math
import numpy as np
import scipy.optimize as optimize


def tahn(xs, a):
    return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]


def sigmoid(xs, a):
    return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]


print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
#       .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])

# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))

sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))

# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()

# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()

print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)

if print_points == 1:
    print(len(xs))
    for x, erf in zip(xs, erfs):
        print(x, erf)

산출:

Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05

2
근사가 필요한 이유는 무엇입니까? 그들은 단지 erf 함수를 사용할 수 없었습니까?
SebiSebi

8

먼저 의 패리티 . 우리는 보여줄 필요가 에 대한 .

Φ(x)=12erfc(x2)=12(1+erf(x2))
erf
erf(x2)tanh(2π(x+ax3))
a0.044715

큰 값의 경우 두 함수 모두 묶입니다 . 작은 경우, 각각의 Taylor 계열은 및 대체하면 및 계수를 동일시하는 것은위한 , 우리는 찾을 수 종이의 가까운x[1,1]x

tanh(x)=xx33+o(x3)
erf(x)=2π(xx33)+o(x3).
tanh(2π(x+ax3))=2π(x+(a23π)x3)+o(x3)
erf(x2)=2π(xx36)+o(x3).
x3
a0.04553992412
0.044715.

당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.