TensorFlow 저장 / 파일에서 그래프로드


98

지금까지 수집 한 내용에서 TensorFlow 그래프를 파일에 덤핑 한 다음 다른 프로그램에로드하는 방법에는 여러 가지가 있지만 작동 방식에 대한 명확한 예제 / 정보를 찾을 수 없었습니다. 내가 이미 알고있는 것은 이것이다 :

  1. a를 사용하여 모델의 변수를 체크 포인트 파일 (.ckpt)에 저장 tf.train.Saver()하고 나중에 복원 ( source )
  2. 모델을 .pb 파일에 저장하고 tf.train.write_graph()tf.import_graph_def()( 소스 ) 를 사용하여 다시로드합니다.
  3. .pb 파일에서 모델을로드하고 다시 학습 한 다음 Bazel을 사용하여 새 .pb 파일에 덤프합니다 ( 소스 ).
  4. 그래프를 고정하여 그래프와 가중치를 함께 저장합니다 ( 소스 ).
  5. 사용 as_graph_def()모델을 저장 및 무게 / 변수 (상수로 매핑 소스 )

그러나 이러한 다른 방법에 대한 몇 가지 질문을 해결할 수 없었습니다.

  1. 체크 포인트 파일과 관련하여 모델의 훈련 된 가중치 만 저장합니까? 체크 포인트 파일을 새 프로그램에로드하여 모델을 실행하는 데 사용할 수 있습니까, 아니면 단순히 특정 시간 / 단계에서 모델의 가중치를 저장하는 방법으로 사용됩니까?
  2. 와 관련 tf.train.write_graph()하여 가중치 / 변수도 저장됩니까?
  3. Bazel과 관련하여 재교육을 위해 .pb 파일로만 저장 /로드 할 수 있습니까? 그래프를 .pb로 덤프하는 간단한 Bazel 명령이 있습니까?
  4. 고정과 관련하여 고정 된 그래프는 tf.import_graph_def()? 를 사용하여로드 할 수 있습니다 .
  5. TensorFlow 용 Android 데모는 .pb 파일에서 Google의 Inception 모델로로드됩니다. 내 자신의 .pb 파일을 대체하려면 어떻게해야합니까? 네이티브 코드 / 메서드를 변경해야합니까?
  6. 일반적으로이 모든 방법의 차이점은 정확히 무엇입니까? 또는 더 광범위하게 as_graph_def()/.ckpt/.pb 의 차이점은 무엇입니까?

요컨대, 제가 찾고있는 것은 그래프 (다양한 연산 등)와 가중치 / 변수를 파일에 저장하는 방법입니다. 그러면 그래프와 가중치를 다른 프로그램에로드하는 데 사용할 수 있습니다. , 사용을 위해 (반드시 계속 / 재교육하는 것은 아님).

이 주제에 대한 문서는 그다지 간단하지 않으므로 답변 / 정보를 보내 주시면 감사하겠습니다.


2
최신 / 가장 완전한 API는 메타 그래프로, 세 가지를 한 번에 모두 저장할 수있는 방법을 제공합니다-1) 그래프 2) 매개 변수 값 3) 컬렉션 : tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html을
야로 슬라브 Bulatov

답변:


80

TensorFlow에서 모델을 저장하는 문제에 접근하는 방법에는 여러 가지가 있으며, 이로 인해 약간 혼란 스러울 수 있습니다. 각 하위 질문을 차례로 수행 :

  1. 체크 포인트 파일 (예 : 객체 를 호출 saver.save()하여 생성됨 tf.train.Saver)에는 가중치와 동일한 프로그램에 정의 된 다른 변수 만 포함됩니다. 다른 프로그램에서 사용하려면 관련 그래프 구조를 다시 만들어야합니다 (예 : 코드를 실행하여 다시 빌드하거나를 호출하여 tf.import_graph_def()). 그러면 TensorFlow에 해당 가중치로 수행 할 작업을 알려줍니다. 또한를 호출 하면 그래프와 체크 포인트의 가중치를 해당 그래프와 연결하는 방법에 대한 세부 정보가 포함 saver.save()된 파일이 생성 MetaGraphDef됩니다. 자세한 내용 은 튜토리얼 을 참조하십시오.

  2. tf.train.write_graph()그래프 구조 만 작성합니다. 가중치가 아닙니다.

  3. Bazel은 TensorFlow 그래프를 읽거나 쓰는 것과 관련이 없습니다. (아마도 귀하의 질문을 오해하고 있습니다. 의견을 통해 자유롭게 질문하십시오.)

  4. 고정 된 그래프는 tf.import_graph_def(). 이 경우 가중치는 (일반적으로) 그래프에 포함되므로 별도의 체크 포인트를로드 할 필요가 없습니다.

  5. 주요 변경 사항은 모델에 공급되는 텐서의 이름과 모델에서 가져온 텐서의 이름을 업데이트하는 것입니다. TensorFlow Android 데모에서 이는 에 전달 되는 inputNameoutputName문자열에 해당합니다 TensorFlowClassifier.initializeTensorFlow().

  6. GraphDef일반적으로 교육 과정을 변경하지 않는 프로그램 구조입니다. 체크 포인트는 일반적으로 교육 프로세스의 모든 단계에서 변경되는 교육 프로세스 상태의 스냅 샷입니다. 결과적으로 TensorFlow는 이러한 유형의 데이터에 대해 서로 다른 저장 형식을 사용하고 저수준 API는 데이터를 저장하고로드하는 다양한 방법을 제공합니다. 같은과 같은 높은 수준의 도서관, MetaGraphDef도서관, Kerasskflow 이러한 메커니즘에 빌드 저장하고 전체 모델을 복원하는 편리한 방법을 제공합니다.


이것은 저장된 그래프를로드 한 다음 실행할 수 있다고 C ++ API 문서가 거짓말을 한다는 것을 의미합니까 tf.train.write_graph()?
mnicky

2
C ++ API 문서는 거짓말을하지 않지만 몇 가지 세부 정보가 누락되었습니다. 가장 중요한 세부 사항은에 GraphDef의해 저장되는 tf.train.write_graph()것 외에도 그래프를 실행할 때 공급하고 가져 오려는 텐서의 이름도 기억해야한다는 것입니다 (위의 항목 5).
mrry

@mrry : Tensorflows DeepDream 예제를 사용해 보았습니다. 하지만 pb 형식의 사전 훈련 된 모델이 필요한 것 같습니다! Cifar10 예제를 실행했지만 체크 포인트 만 생성합니다! pb 파일 등을 찾을 수 없습니다! 내 체크 포인트를 deepdream 예제에서 사용하는 pb 형식으로 어떻게 변환 할 수 있습니까?
Rika

2
@ Coderx7 체크 포인트는 가중치와 변수 만 포함하고 그래프 구조에 대해 전혀 알지 못하기 때문에 .ckpt를 .pb로 변환 할 수 없다고 생각합니다
davidivad

1
.pb 파일을로드 한 다음 실행하는 간단한 코드가 있습니까?
홍콩

1

다음 코드를 시도해 볼 수 있습니다.

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.