wahaha commited on
Commit
f937c5e
1 Parent(s): df7ebc5
Files changed (2) hide show
  1. app.py +22 -11
  2. test1.py +26 -13
app.py CHANGED
@@ -18,6 +18,7 @@ from io import BytesIO
18
  sys.path.insert(0, 'animeganv2')
19
 
20
  import test1 as test
 
21
 
22
  ORIGINAL_REPO_URL = 'https://github.com/TachibanaYoshino/AnimeGANv2'
23
  TITLE = 'TachibanaYoshino/AnimeGANv2'
@@ -48,17 +49,16 @@ def parse_args() -> argparse.Namespace:
48
 
49
  def run(
50
  image,
51
- checkpoint_dir: str,
52
- ) -> tuple[PIL.Image.Image]:
 
 
53
 
54
- curPath = os.path.abspath(os.path.dirname(__file__))
55
- checkpoint_dir = os.path.join(curPath, 'animeganv2/checkpoint/generator_Shinkai_weight')
56
- print(checkpoint_dir)
57
-
58
- outs = test.test(checkpoint_dir, 'save_dir', image.name, True)
59
 
60
-
61
- return PIL.Image.open(outs[0])
62
 
63
 
64
  def main():
@@ -66,8 +66,13 @@ def main():
66
 
67
  args = parse_args()
68
 
 
 
 
 
 
69
 
70
- func = functools.partial(run, checkpoint_dir='')
71
  func = functools.update_wrapper(func, run)
72
 
73
 
@@ -79,7 +84,13 @@ def main():
79
  [
80
  gr.outputs.Image(
81
  type='pil',
82
- label='Result'),
 
 
 
 
 
 
83
  ],
84
  #examples=examples,
85
  theme=args.theme,
 
18
  sys.path.insert(0, 'animeganv2')
19
 
20
  import test1 as test
21
+ from test1 import ImportGraph
22
 
23
  ORIGINAL_REPO_URL = 'https://github.com/TachibanaYoshino/AnimeGANv2'
24
  TITLE = 'TachibanaYoshino/AnimeGANv2'
 
49
 
50
  def run(
51
  image,
52
+ shinkai: ImportGraph,
53
+ hayao: ImportGraph,
54
+ paprika: ImportGraph,
55
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
56
 
57
+ im1 = shinkai.test(image.name, True)
58
+ im2 = hayao.test(image.name, True)
59
+ im3 = paprika.test(image.name, True)
 
 
60
 
61
+ return PIL.Image.open(im1),PIL.Image.open(im2),PIL.Image.open(im3)
 
62
 
63
 
64
  def main():
 
66
 
67
  args = parse_args()
68
 
69
+ curPath = os.path.abspath(os.path.dirname(__file__))
70
+ #init
71
+ shinkai = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Shinkai_weight'))
72
+ hayao = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Hayao_weight'))
73
+ paprika = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Paprika_weight'))
74
 
75
+ func = functools.partial(run, shinkai=shinkai,hayao=hayao,paprika=paprika )
76
  func = functools.update_wrapper(func, run)
77
 
78
 
 
84
  [
85
  gr.outputs.Image(
86
  type='pil',
87
+ label='Shinkai Result'),
88
+ gr.outputs.Image(
89
+ type='pil',
90
+ label='Hayao Result'),
91
+ gr.outputs.Image(
92
+ type='pil',
93
+ label='Paprika Result'),
94
  ],
95
  #examples=examples,
96
  theme=args.theme,
test1.py CHANGED
@@ -8,22 +8,35 @@ import numpy as np
8
  from net import generator
9
  os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
10
 
11
- def parse_args():
12
- desc = "AnimeGANv2"
13
- parser = argparse.ArgumentParser(description=desc)
 
 
 
 
 
 
14
 
15
- parser.add_argument('--checkpoint_dir', type=str, default='checkpoint/'+'generator_Shinkai_weight',
16
- help='Directory name to save the checkpoints')
17
- parser.add_argument('--test_dir', type=str, default='dataset/test/t',
18
- help='Directory name of test photos')
19
- parser.add_argument('--save_dir', type=str, default='Shinkai/t',
20
- help='what style you want to get')
21
- parser.add_argument('--if_adjust_brightness', type=bool, default=True,
22
- help='adjust brightness by the real photo')
 
 
 
 
 
 
 
 
23
 
24
- """checking arguments"""
25
 
26
- return parser.parse_args()
27
 
28
  def stats_graph(graph):
29
  flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
 
8
  from net import generator
9
  os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
10
 
11
+ class ImportGraph:
12
+ def __init__(self, checkpoint_dir):
13
+ self.graph = tf.Graph()
14
+ self.sess = tf.Session(graph=self.graph, config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options))
15
+ with self.graph.as_default():
16
+ test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
17
+ with tf.variable_scope("generator", reuse=False):
18
+ test_generated = generator.G_net(test_real).fake
19
+ saver = tf.train.Saver()
20
 
21
+ ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information
22
+ if ckpt and ckpt.model_checkpoint_path:
23
+ ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # first line
24
+ saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
25
+ print(" [*] Success to read {}".format(os.path.join(checkpoint_dir, ckpt_name)))
26
+ else:
27
+ print(" [*] Failed to find a checkpoint")
28
+
29
+ def test(self, sample_file, if_adjust_brightness, img_size=[256,256]):
30
+ sample_image = np.asarray(load_test_data(sample_file, img_size))
31
+ image_path = os.path.join(result_dir, '{0}'.format(os.path.basename(sample_file)))
32
+ fake_img = sess.run(test_generated, feed_dict={test_real: sample_image})
33
+ if if_adjust_brightness:
34
+ save_images(fake_img, image_path, sample_file)
35
+ else:
36
+ save_images(fake_img, image_path, None)
37
 
38
+ return image_path
39
 
 
40
 
41
  def stats_graph(graph):
42
  flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())