Documenti di Didattica
Documenti di Professioni
Documenti di Cultura
# Command-line flags for the training script.
# NOTE(review): train() also reads FLAGS.batch_size, image_size1, image_size2,
# use_lsgan, norm, lambda1, lambda2, learning_rate, beta1, ngf and pool_size,
# which are not defined in this section -- confirm they are defined elsewhere.
FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string('X', 'data/tfrecords/apple.tfrecords',
                       'X tfrecords file for training, '
                       'default: data/tfrecords/apple.tfrecords')
tf.flags.DEFINE_string('Y', 'data/tfrecords/orange.tfrecords',
                       'Y tfrecords file for training, '
                       'default: data/tfrecords/orange.tfrecords')
tf.flags.DEFINE_string('load_model', None,
                       'folder of saved model that you wish to continue '
                       'training (e.g. 20170602-1936), default: None')
def _step_from_checkpoint(checkpoint_path):
    """Return the global step encoded in a checkpoint path.

    ``saver.save(..., global_step=step)`` produces checkpoint prefixes ending
    in ``-<step>``, e.g. ``checkpoints/Hazy2GT/model.ckpt-200000``. Only the
    last ``-`` is split on, so directory names that themselves contain dashes
    (such as ``20170602-1936``) are handled. A trailing extension such as
    ``.meta`` is ignored.
    """
    step_part = checkpoint_path.rsplit("-", 1)[1]
    return int(step_part.split(".")[0])


def train():
    """Build the CycleGAN graph and train it until the coordinator stops.

    Checkpoints and TensorBoard summaries are written to ``checkpoints_dir``.
    This is ``checkpoints/<FLAGS.load_model>`` when resuming, otherwise the
    fixed directory ``checkpoints/Hazy2GT``. When resuming, the latest
    checkpoint in that directory is restored, and the step counter continues
    from the step encoded in its file name. A checkpoint is saved every 10000
    steps, and once more when training ends (normal stop, Ctrl-C, or error).

    Raises:
        ValueError: if ``FLAGS.load_model`` is set but its directory contains
            no checkpoint to resume from.
    """
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        # NOTE(review): fresh runs always use this fixed directory, so they
        # write alongside any earlier run's checkpoints. The upstream
        # variant used a timestamped directory:
        #   "checkpoints/{}".format(datetime.now().strftime("%Y%m%d-%H%M"))
        checkpoints_dir = "checkpoints/Hazy2GT"
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            # Best effort: typically the directory already exists.
            pass

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size1=FLAGS.image_size1,
            image_size2=FLAGS.image_size2,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf
        )
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        # Created after the model so it covers all of the model's variables;
        # it is used both for resuming and for periodic saves.
        saver = tf.train.Saver()

    config = tf.ConfigProto()
    # Cap this process at half of the GPU's memory.
    config.gpu_options.per_process_gpu_memory_fraction = 0.5

    with tf.Session(config=config, graph=graph) as sess:
        if FLAGS.load_model is not None:
            latest = tf.train.latest_checkpoint(checkpoints_dir)
            if latest is None:
                raise ValueError(
                    "No checkpoint found in %s to resume from" % checkpoints_dir)
            logging.info("Restoring model from %s", latest)
            # Restore into the graph built above rather than re-importing the
            # meta graph, so the training ops use the restored variables.
            saver.restore(sess, latest)
            step = _step_from_checkpoint(latest)
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # Fetch the current generator outputs. ImagePool.query
                # (defined elsewhere) decides what is actually fed to the
                # cycle_gan.fake_y / cycle_gan.fake_x placeholders.
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                        feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                                   cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
                    )
                )

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------', step)
                    logging.info(' G_loss : %s', G_loss_val)
                    logging.info(' D_Y_loss : %s', D_Y_loss_val)
                    logging.info(' F_loss : %s', F_loss_val)
                    logging.info(' D_X_loss : %s', D_X_loss_val)

                if step % 10000 == 0:
                    save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s", save_path)
                    # Project-specific post-save hook. It runs synchronously;
                    # if the script is missing, subprocess.call raises OSError,
                    # which ends training through the handler below.
                    subprocess.call("./create_model.sh")

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Record the error; coord.join() below re-raises it after the
            # final checkpoint has been written.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s", save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def main(unused_argv):
    """Entry point for ``tf.app.run``; the positional argv is not used."""
    del unused_argv  # Configuration is read through FLAGS inside train().
    train()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    tf.app.run(main=main)