Spaces:
Running
Running
File size: 4,749 Bytes
77772bc 7155406 77772bc 7155406 77772bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import base64
import io
import itertools

from datasets import load_dataset, Features, Value, Image, Sequence
from flask import Flask, render_template, request
from PIL import Image as PILImage  # alias PIL's Image so it does not clash with datasets.Image
# Flask application object; the route below is registered on it.
app = Flask(__name__)
# Hugging Face dataset repository ID -- replace with your own dataset ID.
DATASET_ID = "ma-xu/fine-t2i"
def image_to_base64(pil_img, *, max_size=(800, 800), quality=85):
    """Encode a PIL image as a JPEG ``data:`` URI.

    The image is shrunk to fit inside ``max_size`` (aspect ratio preserved),
    saved as JPEG and Base64-encoded so it can be embedded directly in HTML.

    Args:
        pil_img: A ``PIL.Image.Image``. It is left untouched; all resizing
            and conversion happens on a copy.
        max_size: ``(width, height)`` bounding box for the thumbnail.
        quality: JPEG quality passed to ``Image.save``.

    Returns:
        A string of the form ``data:image/jpeg;base64,<payload>``.
    """
    # Newer Pillow releases expose the filters on Image.Resampling; older
    # ones only provide the module-level constant.
    resampling_mode = getattr(PILImage, "Resampling", PILImage).LANCZOS

    # Modes the JPEG encoder writes directly are kept as they are; anything
    # else (RGBA, P, LA, ...) becomes RGB. Converting *before* resizing
    # matters for palette images, which Pillow would otherwise resize with
    # nearest-neighbour regardless of the requested filter.
    if pil_img.mode in ("L", "RGB", "CMYK", "YCbCr"):
        # thumbnail() works in place, so copy to avoid shrinking the
        # caller's image.
        img = pil_img.copy()
    else:
        img = pil_img.convert("RGB")

    img.thumbnail(max_size, resampling_mode)

    buffer = io.BytesIO()
    img.save(buffer, format='JPEG', quality=quality)
    encoded_img = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded_img}"
# Upper bound on the number of samples rendered per request.
_MAX_SAMPLES = 100

# Short form value -> dataset sub-folder holding the WebDataset shards.
_FOLDER_BY_KEY = {
    "curated": "curated",
    "enhanced_prompt_random_resolution": "synthetic_enhanced_prompt_random_resolution",
    "enhanced_prompt_square_resolution": "synthetic_enhanced_prompt_square_resolution",
    "original_prompt_random_resolution": "synthetic_original_prompt_random_resolution",
    "original_prompt_square_resolution": "synthetic_original_prompt_square_resolution",
}

# Folder used when the submitted value is not recognised.
_FALLBACK_FOLDER = "synthetic_original_prompt_square_resolution"


def _resolve_url_pattern(folder_key):
    """Map a submitted folder value to the glob URL of its tar shards.

    Both the short keys and the full folder names are accepted, so the
    route's default value (a full folder name) selects the folder it names
    instead of silently falling through to the fallback.
    """
    if folder_key in _FOLDER_BY_KEY:
        folder = _FOLDER_BY_KEY[folder_key]
    elif folder_key in _FOLDER_BY_KEY.values():
        folder = folder_key
    else:
        folder = _FALLBACK_FOLDER
    return f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{folder}/train-*.tar"


def _build_features():
    """Return the explicit schema handed to ``load_dataset``."""
    json_feature = {
        "aesthetic_predictor_v_2_5_score": Value("double"),
        "enhanced_length": Value("int64"),
        "enhanced_prompt": Value("string"),
        "enhancer": Value("string"),
        "id": Value("string"),
        "image_aspect_ratio": Value("string"),
        "image_generated_with_enhanced_prompt": Value("bool"),
        "image_generator": Value("string"),
        "image_resolution": Sequence(Value("int64")),
        "length": Value("int64"),
        "prompt": Value("string"),
        "prompt_category": Value("string"),
        "prompt_generator": Value("string"),
        "style": Value("string"),
        "task": Sequence(Value("string")),  # corresponds to list<item: string>
    }
    return Features({
        "__key__": Value("string"),
        "__url__": Value("string"),
        "jpg": Image(),
        "txt": Value("string"),
        # Matches the list<item: string> type mentioned in the error message.
        # If the json column is actually a plain dict, it can instead be
        # declared as Value("string") and parsed with json.loads afterwards.
        "json": json_feature,
    })


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render a page of randomly sampled dataset entries.

    The sub-folder to sample from is read from the ``folder`` form field.
    For GET and POST requests the matching shards are streamed, shuffled
    with a buffer of 100 and up to ``_MAX_SAMPLES`` entries are collected;
    any other method renders the template with no samples.

    Returns:
        The rendered ``index.html`` template, or an
        ``("Error loading dataset. Please check logs.", 500)`` tuple when
        loading or iterating the dataset fails.
    """
    samples = []
    filter_folder = request.form.get('folder', 'synthetic_enhanced_prompt_random_resolution')
    print(f"filter_folder: {filter_folder}")
    url_pattern = _resolve_url_pattern(filter_folder)

    if request.method in ('POST', 'GET'):
        features = _build_features()
        try:
            dataset = load_dataset(
                "webdataset",
                data_files={"train": url_pattern},
                split="train",
                features=features,  # force the explicit schema
                streaming=True,  # please do streaming, or you will download the whole dataset ~2TB
            )
            shuffled = dataset.shuffle(buffer_size=100)
            # islice stops after the last wanted item instead of pulling one
            # extra sample from the remote stream just to discard it.
            for count, item in enumerate(itertools.islice(shuffled, _MAX_SAMPLES)):
                print(count)
                # Field names must match the keys in the dataset (jpg, json, txt).
                samples.append({
                    "image": image_to_base64(item['jpg']),
                    "text": item.get('txt', ''),
                    "json": item.get('json', {}),
                    "tar_file": item["__url__"],
                })
        except Exception as e:
            # Request boundary: log the message with its traceback (this ends
            # up in the HF Space logs) and answer with a plain 500.
            app.logger.exception("Runtime Error: %s", e)
            return "Error loading dataset. Please check logs.", 500

    return render_template('index.html', samples=samples)
if __name__ == '__main__':
    # A Hugging Face Space must serve on port 7860.
    app.run(host='0.0.0.0', port=7860)