확률 적 경사 하강은 어떻게 표준 경사 하강에 비해 시간을 절약 할 수 있습니까?


15

표준 그라디언트 디센트는 전체 교육 데이터 세트에 대한 그라디언트를 계산합니다.

for i in range(nb_epochs):
  params_grad = evaluate_gradient(loss_function, data, params)
  params = params - learning_rate * params_grad

사전 정의 된 에포크 수에 대해 먼저 매개 변수 벡터 매개 변수를 사용하여 전체 데이터 세트에 대한 손실 함수의 기울기 벡터 weights_grad를 계산합니다.

대조적으로 확률 적 그라디언트 디센트는 각 훈련 예 x (i) 및 라벨 y (i)에 대한 파라미터 업데이트를 수행합니다.

for i in range(nb_epochs):
  np.random.shuffle(data)
  for example in data:
    params_grad = evaluate_gradient(loss_function, example, params)
    params = params - learning_rate * params_grad

SGD는 훨씬 빠릅니다. 그러나 여전히 모든 데이터 포인트에 루프가있는 경우 훨씬 더 빠를 수있는 방법을 이해하지 못합니다. GD에서 그래디언트 계산이 각 데이터 포인트에 대한 GD 계산보다 훨씬 느립니까?

코드는 여기 에서 온다 .


1
두 번째 경우 전체 데이터 세트를 근사하기 위해 작은 배치를 사용합니다. 이것은 일반적으로 꽤 잘 작동합니다. 혼란스러운 부분은 아마도 두 시대 모두 에포크의 수가 같다고 생각할 것입니다. 그러나 케이스 2에 많은 에포크가 필요하지 않을 것입니다. "하이 파라미터"는이 두 가지 방법에서 다를 것입니다 : GD nb_epochs! = SGD nb_epochs. 인수의 목적으로 GD nb_epochs = SGD 예제 * nb_epochs를 사용하여 총 루프 수는 동일하지만 그래디언트 계산은 SGD에서 훨씬 빠릅니다.
니마 무사 비

이력서에 대한이 답변 은 좋고 관련이 있습니다.
바브

답변:


23

짧은 답변:

  • 많은 빅 데이터 설정 (예 : 수백만 개의 데이터 포인트)에서 모든 데이터 포인트를 합산해야하므로 비용 또는 그라디언트를 계산하는 데 시간이 오래 걸립니다.
  • 주어진 반복에서 비용을 줄이기 위해 정확한 기울기를 가질 필요없습니다 . 그래디언트의 근사치가 정상적으로 작동합니다.
  • SGcha (Stochastic Gradient Decent)는 하나의 데이터 포인트 만 사용하여 그라디언트와 유사합니다. 따라서 그래디언트를 평가하면 모든 데이터를 합산하는 것과 비교하여 많은 시간이 절약됩니다.
  • "합리적인"반복 횟수 (이 수는 수만 개일 수 있고 수백만 개일 수있는 데이터 포인트 수보다 훨씬 적을 수 있음)를 사용하면 확률 적 그라디언트가 적절한 솔루션을 얻을 수 있습니다.

긴 대답 :

Andrew NG의 기계 학습 과정 과정을 따릅니다. 익숙하지 않은 경우 여기 에서 강의 시리즈를 검토 할 수 있습니다 .

제곱 손실에 대한 회귀를 가정하자, 비용 함수는

제이(θ)=12미디엄나는=1미디엄(hθ(엑스(나는))와이(나는))2

그라디언트는

제이(θ)θ=1미디엄나는=1미디엄(hθ(엑스(나는))와이(나는))엑스(나는)

그래디언트 디센트 (GD)의 경우

θ이자형=θ영형α1미디엄나는=1미디엄(hθ(엑스(나는))와이(나는))엑스(나는)

1/미디엄엑스(나는),와이(나는) 와서 시간을 절약하십시오.

θ이자형=θ영형α(hθ(엑스(나는))와이(나는))엑스(나는)

시간을 절약 할 수있는 이유는 다음과 같습니다.

10 억 개의 데이터 포인트가 있다고 가정하십시오.

  • GD에서 매개 변수를 한 번 업데이트하려면 (정확한) 그래디언트가 필요합니다. 이를 위해서는 1 개의 업데이트를 수행하기 위해이 10 억 개의 데이터 포인트를 요약해야합니다.

  • SGD에서 우리는 그것을 얻는 것을 정확한 기울기 대신 근사 기울기 . 근사치는 하나의 데이터 포인트 (또는 미니 배치라고하는 여러 데이터 포인트)에서 나옵니다. 따라서 SGD에서 매개 변수를 매우 빠르게 업데이트 할 수 있습니다. 또한 모든 데이터 (하나의 에포크라고 함)를 "반복"하면 실제로 10 억 개의 업데이트가 있습니다.

트릭은 SGD에서 10 억 회 반복 / 업데이트 할 필요는 없지만 1 백만 회 정도로 훨씬 적은 반복 / 업데이트가 필요하며 사용하기에 "충분한"모델을 갖게된다는 것입니다.


아이디어를 시연하는 코드를 작성 중입니다. 먼저 선형 방정식을 정규 방정식으로 풀고 SGD로 풀기합니다. 그런 다음 결과를 매개 변수 값과 최종 목적 함수 값으로 비교합니다. 나중에 시각화하기 위해 튜닝 할 두 개의 매개 변수가 있습니다.

set.seed(0);n_data=1e3;n_feature=2;
A=matrix(runif(n_data*n_feature),ncol=n_feature)
b=runif(n_data)
res1=solve(t(A) %*% A, t(A) %*% b)

sq_loss<-function(A,b,x){
  e=A %*% x -b
  v=crossprod(e)
  return(v[1])
}

sq_loss_gr_approx<-function(A,b,x){
  # note, in GD, we need to sum over all data
  # here i is just one random index sample
  i=sample(1:n_data, 1)
  gr=2*(crossprod(A[i,],x)-b[i])*A[i,]
  return(gr)
}

x=runif(n_feature)
alpha=0.01
N_iter=300
loss=rep(0,N_iter)

for (i in 1:N_iter){
  x=x-alpha*sq_loss_gr_approx(A,b,x)
  loss[i]=sq_loss(A,b,x)
}

결과 :

as.vector(res1)
[1] 0.4368427 0.3991028
x
[1] 0.3580121 0.4782659

124.1343123.0355 이며 매우 가깝습니다.

반복에 대한 비용 함수 값은 다음과 같습니다. 손실을 효과적으로 줄일 수 있다는 것을 알 수 있습니다. 이는 데이터의 하위 집합을 사용하여 그래디언트를 근사화하고 "충분한"결과를 얻을 수 있습니다.

여기에 이미지 설명을 입력하십시오

여기에 이미지 설명을 입력하십시오

1000sq_loss_gr_approx3001000


"속도"에 대한 논쟁은 지역 최적에 수렴하기 위해 얼마나 많은 작업 / 반복이 필요한지에 대한 것이라고 생각 했습니까? (그리고 확률 적 구배 하강은 더 나은 최적화 로 수렴하는 경향이 있습니다.)
GeoMatt22

내가 이해하는 한, 파이썬 코드에서 "data"-변수는 동일합니다. 미니 배치 그라디언트 괜찮은-코드는 SDG와 다릅니다 (정확히 데이터의 작은 부분 만 사용합니다). 또한 제공 한 설명에서 SDG에서 합계를 제거하더라도 각 데이터 포인트에 대한 업데이트를 계산합니다. 각 데이터 포인트를 반복하는 동안 매개 변수를 업데이트하는 것이 모든 데이터 포인트를 한 번에 합산하는 것보다 빠른 방법을 여전히 이해하지 못합니다.
Alina

@ GeoMatt22 내가 제공 한 링크에서 "SGD는 계속해서 오버 슈팅을 계속하기 때문에 궁극적으로 정확한 최소 수렴을 복잡하게합니다." 그것은 더 나은 최적화로 수렴하지 않는다는 것을 의미합니다. 아니면 내가 잘못 했습니까?
Alina

@Tonja 저는 전문가는 아니지만, 예를 들어 딥 러닝에 관한 영향력있는 논문은 확률 적 경사 하강에 대한 "보다 빠르고 안정적인 교육"주장을 제시합니다. "원시"버전은 사용하지 않지만 다양한 곡률 추정을 사용하여 (좌표 의존) 학습 속도를 설정합니다.
GeoMatt22

1
@Tonja, 그렇습니다. 그래디언트의 "약한"근사치가 작동합니다. 비슷한 아이디어 인 "그라데이션 부스팅"을 확인할 수 있습니다. 반면에, 아이디어를 시연하기 위해 몇 가지 코드를 작성하고 있습니다. 준비가되면 게시하겠습니다.
Haitao Du
당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.