Junjie96 commited on
Commit
a397ba1
1 Parent(s): 83db97d

Update src/util.py

Browse files
Files changed (1) hide show
  1. src/util.py +10 -15
src/util.py CHANGED
@@ -3,7 +3,6 @@ import io
3
  import os
4
  import time
5
 
6
- import cv2
7
  import oss2
8
  import requests
9
  from PIL import Image
@@ -20,17 +19,13 @@ bucket = oss2.Bucket(oss2.Auth(access_key_id, access_key_secret), endpoint, buck
20
  oss_path = os.getenv("OSS_PATH")
21
 
22
 
23
- def resize(img, short_side_length=512):
24
- height, width, _ = img.shape
25
- aspect_ratio = width / height
26
- if width > height:
27
- new_width = short_side_length
28
- new_height = int(new_width / aspect_ratio)
29
- else:
30
- new_height = short_side_length
31
- new_width = int(new_height * aspect_ratio)
32
- resized_img = cv2.resize(img, (new_width, new_height))
33
- return resized_img
34
 
35
 
36
  def download_img_pil(index, img_url):
@@ -60,13 +55,13 @@ def download_images(img_urls, batch_size):
60
 
61
  def upload_np_2_oss(input_image, name="cache.jpg"):
62
  assert name.lower().endswith((".png", ".jpg")), name
63
- if name.lower().endswith(".png"):
64
  name = name[:-4] + ".jpg"
65
  imgByteArr = io.BytesIO()
66
  if name.lower().endswith(".png"):
67
- Image.fromarray(resize(input_image)).save(imgByteArr, format="PNG")
68
  else:
69
- Image.fromarray(resize(input_image)).save(imgByteArr, format="JPEG", quality=95)
70
  imgByteArr = imgByteArr.getvalue()
71
 
72
  start_time = time.perf_counter()
 
3
  import os
4
  import time
5
 
 
6
  import oss2
7
  import requests
8
  from PIL import Image
 
19
  oss_path = os.getenv("OSS_PATH")
20
 
21
 
22
+ def resize(image, short_side_length=512):
23
+ width, height = image.size
24
+ ratio = short_side_length / min(width, height)
25
+ new_width = int(width * ratio)
26
+ new_height = int(height * ratio)
27
+ resized_image = image.resize((new_width, new_height))
28
+ return resized_image
 
 
 
 
29
 
30
 
31
  def download_img_pil(index, img_url):
 
55
 
56
  def upload_np_2_oss(input_image, name="cache.jpg"):
57
  assert name.lower().endswith((".png", ".jpg")), name
58
+ if name.endswith(".png"):
59
  name = name[:-4] + ".jpg"
60
  imgByteArr = io.BytesIO()
61
  if name.lower().endswith(".png"):
62
+ resize(Image.fromarray(input_image)).save(imgByteArr, format="PNG")
63
  else:
64
+ resize(Image.fromarray(input_image)).save(imgByteArr, format="JPEG", quality=95)
65
  imgByteArr = imgByteArr.getvalue()
66
 
67
  start_time = time.perf_counter()