tensorflow 로 checkpoint 파일(.ckpt)과 ckeckpoint state proto 이용하기 : Variable 저장과 재활용

deeplearning | 29 July 2017

Tags | tensorflow python

tensorflow 로 학습시킨 딥러닝 모델을 저장하는 방법중 하나로 Checkpoint 을 이용하는 방법이 있다.

Checkpoint 은 학습된 모델의 Variable 값을 저장하는 파일이다.

Checkpoint 파일을 저장하고 불러옴으로써 학습된 모델을 재사용하고, 지난 학습을 이어서 더 하고 하는 작업들이 가능해진다.

이번 글에서는 Checkpoint 파일을 다루는 tensorflow 모듈에 대해서 알아보고, 능숙하게 딥러닝 모델을 저장하고 불러오는 작업을 수행할 수 있도록 하는 것이 목표이다.



0. 모델 디렉터리 구조



project(root)
ㄴmodel
~ ㄴgraph.py
~ ㄴrunner.py
~ ㄴutils.py
ㄴdata
~ ㄴtrain_data
~ ㄴtest_data
ㄴsaved
~

프로젝트 디렉터리 구조가 이렇다고 가정한다.

#/model/graph.py

import tensorflow as tf

g = tf.Graph()

with g.as_default() :
    v1 = tf.Variable(10, name = "v1")
    v2 = tf.Variable(11, name = "v2")



1. checkpoint 저장



위의 모델 그래프를 학습시키면서, 학습된 모델의 Variable 들을 checkpoint에 저장해본다.

첫번째 training job 의 이름을 train1 이라고 하자.

train1 job 의 결과물은 /saved/train1.ckpt 에 저장할 것이다.

그러기 위해서는 Checkpoint파일을 저장해주는 tf.train.Saver() 클래스를 이용해야한다.

참고 : tensorflow api 공식 doc : tf.train.Saver

#/model/train.py

with tf.Session(graph=g) as sess :

    # Saver instance 를 생성한다.
    # Saver.save(sess, ckpt_path)
    # Saver.restore(sess, ckpt_path)

    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())

    # 그래프를 돌리다가 Variable 을 저장하고 싶을 때 Saver.save() 메서드를 사용한다.
    # args : tf.Session, job`s checkpoint file path
    # return : job`s checkpoint file path (String)

    ckpt_path = saver.save(sess, "saved/train1")

    print("job`s ckpt files is save as : ", ckpt_path)
job`s ckpt files is save as :  saved/train1

위의 코드를 수행하고 나면 프로젝트의 /saved 디렉터리에 새로운 파일들이 생성된다.

project(root)
ㄴmodel
~ ㄴgraph.py
~ ㄴrunner.py
~ ㄴutils.py
ㄴdata
~ ㄴtrain_data
~ ㄴtest_data
ㄴsaved
~ ㄴcheckpoint
~ ㄴtrain1.ckpt.data-00000-of-00001
~ ㄴtrain1.ckpt.index
~ ㄴtrain1.ckpt.meta

이들중 job name 인 train1 으로 시작하는 세개의 파일이 train1 job 의 Checkpoint 파일이다.

맨 위의 checkpoint 란 이름의 파일은 조금 이따 이야기한다.



첫번째 training job : train1 의 결과가 만족스럽지 못해서, 모델을 조금 수정해서 다시 training 을 하려고 한다.

이번 job 의 이름은 train2 라고 하자.

train2 job 의 결과물을 /saved/train2.ckpt 에 저장하는데, 이번엔 매 iteration마다 Variables 의 값을 저장하고싶다.

이럴땐 job의 이름을 유지한채로, iteration 별로 Checkpoint file을 별도로 생성할 수 있다.

#/model/train.py(수정함)

with tf.Session(graph=g) as sess :

    # 위와 마찬가지로 Saver 생성

    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())
    for step in range(10):

        # 매 step 마다 모델 저장하고 싶다면 save 메서드에 step 인자를 하나 추가한다.
        # args : tf.Session, job`s checkpoint file path, step
        # return : job`s checkpoint file path (String)

        ckpt_path = saver.save(sess, "saved/train2", step)

        print("save ckpt file:" , ckpt_path)
save ckpt file: saved/train2-0
save ckpt file: saved/train2-1
save ckpt file: saved/train2-2
save ckpt file: saved/train2-3
save ckpt file: saved/train2-4
save ckpt file: saved/train2-5
save ckpt file: saved/train2-6
save ckpt file: saved/train2-7
save ckpt file: saved/train2-8
save ckpt file: saved/train2-9

보다시피 train1 job 과는 다르게 job name 뒤에 iteration의 step 이 적혀져서 총 10묶음의 체크포인트가 만들어진다.

project(root)
ㄴmodel
~ ㄴgraph.py
~ ㄴrunner.py
~ ㄴutils.py
ㄴdata
~ ㄴtrain_data
~ ㄴtest_data
ㄴsaved
~ ㄴcheckpoint
~ ㄴtrain1.ckpt.data-00000-of-00001
~ ㄴtrain1.ckpt.index
~ ㄴtrain1.ckpt.meta
~ ㄴtrain2-0.ckpt.data-00000-of-00001
~ ㄴtrain2-0.ckpt.index
~ ㄴtrain2-0.ckpt.meta
~ ㄴtrain2-1.ckpt.data-00000-of-00001
~ ㄴtrain2-1.ckpt.index
~ ㄴtrain2-1.ckpt.meta
~ ㄴtrain2-2.ckpt.data-00000-of-00001
~ ㄴtrain2-2.ckpt.index
~ ㄴtrain2-2.ckpt.meta

. . .

~ ㄴtrain2-8.ckpt.data-00000-of-00001
~ ㄴtrain2-8.ckpt.index
~ ㄴtrain2-8.ckpt.meta
~ ㄴtrain2-9.ckpt.data-00000-of-00001
~ ㄴtrain2-9.ckpt.index
~ ㄴtrain2-9.ckpt.meta



2. checkpoint state proto



위의 사진은 tensorflow api 공식 doc : tf.train.Saver 사이트에서 볼 수 있는 Checkpoint State Protocol Buffer 에 대한 정보이다.

이름도 거대한 Checkpoint State Protocol Buffer 에 대해서 알 필요가 있다.

Saver 의 save 모듈을 이용해 모델을 저장할 때, Saver 는 Checkpoint State Protocol Buffer 를 /saved/checkpoint 파일에 담아 저장하고, 새로운 job으로 학습할 때 마다 업데이트해 저장한다.

Checkpoint State Protocol Buffer 에는 두가지 정보가 담겨있다.

  1. model_checkpoint_path : 가장 최근에 저장된 job.ckpt 파일의 path 정보
  2. all_model_checkpoint_paths : 최근에 저장된 job_i.ckpt 파일들의 path 정보 list

보통 saved 폴더에서 가장 최신의 체크포인트파일을 불러와 모델을 재학습시키거나 테스트해보려고 할때 사용한다.

all_model_checkpoint_paths 의 가장 마지막 원소는 model_checkpoint_path 와 동일하다.

Checkpoint State Protocol Buffer 이용법

대표적으로 두가지 방법이 있다.

  1. tf.train.get_checkpoint_state(saved_dir_path)
  2. tf.train.latest_checkpoint(saved_dir_path)

1. tf.train.get_checkpoint_state(saved_dir_path)

saved_dir_path 에서 checkpoint 파일 안의 Checkpoint State Protocol Buffer 를 읽어온다.

ckpt_state = tf.train.get_checkpoint_state("saved")

print(type(ckpt_state))

print("첫번째 정보 사용법:", ckpt_state.model_checkpoint_path)
print("두번째 정보 사용법:", ckpt_state.all_model_checkpoint_paths)
<class 'tensorflow.python.training.checkpoint_state_pb2.CheckpointState'>
첫번째 정보 사용법: saved/train2-9
두번째 정보 사용법: ['saved/train2-5', 'saved/train2-6', 'saved/train2-7', 'saved/train2-8', 'saved/train2-9']

2. tf.train.latest_checkpoint(saved_dir_path)

saved_dir_path 에서 checkpoint 파일 안의 Checkpoint State Protocol Buffer 에서 model_checkpoint_path 정보만 string 으로 반환한다.

recent_ckpt_job_path = tf.train.latest_checkpoint("saved")

print(recent_ckpt_job_path)
saved/train2-9



3. checkpoint 불러오기



저장한 체크포인트 파일들에서 Variable 들을 다시 꺼내서 재사용하려면 tf.Saver 클래스의 restore 메서들을 이용한다.

이때 위에서 언급한 Checkpoint State Protocol Buffer 가 매우 요긴하게 쓰인다.

test1 job 을 수행하는데에 train2 job 에서 마지막에 저장한 변수 ckpt 결과물을 로드해 사용하고싶다.

그렇다면 아래의 코드처럼 하면된다.

case 1: 직접 불러올 job.ckpt 명시해주는 경우

#/model/test.py

with tf.Session(graph=g) as sess :

    # Saver instance 를 생성한다.
    # Saver.restore(sess, ckpt_path)

    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())

    # Saver.restore()
    # args : tf.Session, job`s checkpoint file path
    # return : None

    ckpt_path = saver.restore(sess, "saved/train2-9")

INFO:tensorflow:Restoring parameters from saved/train2-9

case 2: tf.train.latest_checkpoint(dir_path) 이용하는 경우

#/model/test.py

with tf.Session(graph=g) as sess :

    # Saver instance 를 생성한다.
    # Saver.restore(sess, ckpt_path)

    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())

    # Saver.restore()
    # args : tf.Session, job`s checkpoint file path
    # return : None

    ckpt_path = saver.restore(sess, tf.train.latest_checkpoint("saved"))

INFO:tensorflow:Restoring parameters from saved/train2-9