DoFn被构造了多少次?

我正在使用apache models python SDK和Dataflow编写一个用于使用TensorFlow模型进行预测的推理管道。我在DoFn中有预测步骤,但我不想每次处理捆绑包时都加载模型,因为这非常昂贵。在文档here中,“如果需要,将在worker上创建参数DoFn的新实例,并在该实例上调用DoFn.Setup方法。这可能是通过反序列化或其他方式。PipelineRunner可以对多个bundle重用DoFn实例。异常终止(通过抛出异常)的DoFn将永远不会被重用。”我注意到如果我像这样写代码

class StatefulGetEmbeddingsDoFn(beam.DoFn):
    def __init__(self, model_dir):
         self.model = None # initialize
         self.model_dir = model_dir

    def process(self, element):
         if not self.model: # load model if model hasn't been loaded yet
             global i
             i += 1
             logging.info('Getting model: {}'.format(i))
             self.model = Model(saved_model_dir=self.model_dir)


         ids, b64 = element
         embeddings = self.model.predict(b64)

         res = [
            {
                'image': _id,
                'embeddings': embedding.tolist()
            } for _id, embedding in zip(ids, embeddings)
         ]
         return res

该模型似乎在每个worker上加载了不止一次(我有一个大约30-40台机器的集群)。有没有办法防止模型被多次加载?我曾期望这个DoFn在每台机器上只构建一次,但从日志中看,似乎不是这样的……

转载请注明出处:http://www.lechuangzk.com/article/20230526/1435546.html