Tensorflow SavedModel에서 사용 된 모든 작업을 나열하는 방법은 무엇입니까?


10

tensorflow.saved_model.saveSavedModel 형식 의 함수를 사용하여 모델을 저장하면 나중에이 모델에 사용 된 Tensorflow Ops를 검색 할 수 있습니다. 모델을 복원 할 수 있으므로 이러한 작업은 그래프에 저장되며 추측은 saved_model.pb파일에 있습니다. 이 protobuf (전체 모델이 아님)를로드하면 protobuf의 라이브러리 부분에이 목록이 나열되지만 지금은 실험적인 기능으로 문서화 및 태그 지정되지 않았습니다. Tensorflow 1.x에서 생성 된 모델에는이 부분이 없습니다.

그렇다면 저장된 모델 형식의 모델에서 사용 된 작업 ( MatchingFiles또는 유사 WriteFile) 목록을 검색하는 빠르고 안정적인 방법은 무엇 입니까?

지금처럼 전체를 얼릴 수 있습니다 tensorflowjs-converter. 또한 지원되는 작업을 확인합니다. LSTM이 모델에있는 경우 현재 작동하지 않습니다 ( 여기 참조) . 작전이 확실히 있기 때문에 더 좋은 방법이 있습니까?

예제 모델 :

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

이 경우 최소한 다음을 포함하여 모든 Ops를 출력 할 것으로 예상됩니다.


1
무엇, 당신이 원하는 것을 정확히 말해 어렵다 saved_model.pb, 그것이다 tf.GraphDef, 또는 SavedModelprotobuf 메시지를? 가 tf.GraphDef호출 된 경우를 gd사용하여 사용 된 작업 목록을 얻을 수 있습니다 sorted(set(n.op for n in gd.node)). 로드 된 모델이 있으면 할 수 있습니다 sorted(set(op.type for op in tf.get_default_graph().get_operations())). 이 경우 에서 SavedModel가져옵니다 tf.GraphDef(예 :) saved_model.meta_graphs[0].graph_def.
jdehesa

저장된 SavedModel에서 op를 검색하고 싶습니다. 실제로, 당신이 설명하는 마지막 옵션. saved_model마지막 예에서 변수 는 무엇입니까 ? tf.saved_model.load('/path/to/model')saved_model.pb 파일의 protobuf 결과 또는로드
sampers

답변:


1

경우 saved_model.pbA는 SavedModelprotobuf 메시지가, 당신은 거기에서 직접 작업을 얻을. 다음과 같이 모델을 생성한다고 가정 해 봅시다.

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

이제 다음과 같이 해당 모델에서 사용되는 작업을 찾을 수 있습니다.

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin

나는 이런 식으로 뭔가를 시도했지만 불행하게도 이것은 내가 그것을하지 기대하지 않는 무엇을 : 말 나는이 작업을 수행하는 모델을 가지고 : input_scalar = tf.reshape(file_name, []) output = tf.io.read_file(input_scalar) return tf.stack([output], name='content')그런 다음 ReadFile을 연산 나와있는 여기 가에 있지만 인쇄되지 않습니다.
sampers

1
@ sampers 나는 당신이 제안한 것과 같은 예를 들어 답을 편집했습니다. ReadFile출력 에서 작업을 얻습니다 . 실제 사례에서 저장된 모델의 입력과 출력 사이에 해당 작업이 없을 수 있습니까? 이 경우 정리 될 수 있다고 생각합니다.
jdehesa

실제로 주어진 모델에서 작동합니다. 불행히도 tf2로 만든 모듈은 그렇지 않습니다. 이전 주석에 나열된 호출을 포함하여 file_name인수 @tf.function주석 이있는 1 함수로 tf.Module을 작성 하면 다음 목록이 표시됩니다.Const, NoOp, PartitionedCall, Placeholder, StatefulPartitionedCall
sampers

내 질문에 모델을 추가
sampers

@ sampers 내 답변을 업데이트했습니다. 나는 TF 1.x를 사용하기 전에 TF 2.x에서 그래프 정의 객체의 변경 사항에 익숙하지 않았다. 이제 답변이 저장된 모델의 모든 것을 포함한다고 생각한다. 필자가 작성한 Python 함수 saved_model.meta_graphs[0].graph_def.library.function[0]( node_def해당 함수 객체 내의 컬렉션)에 해당하는 작업이 있다고 생각합니다 .
jdehesa
당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.