# Spaces: Running
import base64
import io
from itertools import islice

from datasets import load_dataset, Features, Value, Image, Sequence
from flask import Flask, render_template, request
from PIL import Image as PILImage  # aliased so it does not clash with datasets.Image
# Hugging Face dataset repository browsed by this app; replace with your own dataset ID.
DATASET_ID = "ma-xu/fine-t2i"

# Flask application object that the view function and the dev server attach to.
app = Flask(__name__)
def image_to_base64(pil_img):
    """Encode a PIL image as a base64 JPEG data URI sized for fast page loads.

    The image is shrunk to fit within 800x800 pixels (aspect ratio kept) and
    re-encoded as JPEG at quality 85.

    Args:
        pil_img: A ``PIL.Image.Image``. It is left unmodified; all work is
            done on a copy.

    Returns:
        A ``data:image/jpeg;base64,...`` string.
    """
    # Compatibility: newer Pillow exposes the filter as Image.Resampling.LANCZOS,
    # older releases only provide the module-level Image.LANCZOS constant.
    try:
        resampling_mode = PILImage.Resampling.LANCZOS
    except AttributeError:
        resampling_mode = PILImage.LANCZOS

    # JPEG cannot store alpha, palette, or integer/float modes, so normalise
    # everything except plain RGB / greyscale to RGB. Converting *before*
    # resizing matters for palette images: Pillow's resize documentation notes
    # that mode "1"/"P" images are always resampled with NEAREST.
    # thumbnail() resizes in place, so the RGB/L path works on a copy to avoid
    # mutating the caller's image (convert() already returns a new image).
    if pil_img.mode in ("RGB", "L"):
        thumb = pil_img.copy()
    else:
        thumb = pil_img.convert("RGB")

    max_size = (800, 800)
    thumb.thumbnail(max_size, resampling_mode)

    # Serialise to an in-memory JPEG and wrap it in a data URI.
    byte_arr = io.BytesIO()
    thumb.save(byte_arr, format='JPEG', quality=85)
    encoded_img = base64.b64encode(byte_arr.getvalue()).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded_img}"
# Folder choice (as submitted by the form, without the "synthetic_" prefix)
# mapped to the sub-directory of the dataset repository that holds its shards.
_FOLDER_SUBDIRS = {
    "curated": "curated",
    "enhanced_prompt_random_resolution": "synthetic_enhanced_prompt_random_resolution",
    "enhanced_prompt_square_resolution": "synthetic_enhanced_prompt_square_resolution",
    "original_prompt_random_resolution": "synthetic_original_prompt_random_resolution",
    "original_prompt_square_resolution": "synthetic_original_prompt_square_resolution",
}
# Sub-directory used when the submitted folder name is not recognised.
_FALLBACK_SUBDIR = "synthetic_original_prompt_square_resolution"
# Number of samples rendered per page, and size of the streaming shuffle buffer.
_MAX_SAMPLES = 100
_SHUFFLE_BUFFER_SIZE = 100


def _resolve_url_pattern(folder):
    """Return the shard URL glob for a folder choice, with or without the 'synthetic_' prefix."""
    key = folder or ""
    # Accept the full sub-directory name too, so that the view's default value
    # ('synthetic_enhanced_prompt_random_resolution') selects the folder it names
    # instead of silently dropping to the fallback.
    if key.startswith("synthetic_"):
        key = key[len("synthetic_"):]
    subdir = _FOLDER_SUBDIRS.get(key, _FALLBACK_SUBDIR)
    return f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{subdir}/train-*.tar"


def _build_features():
    """Return the explicit feature schema passed to load_dataset for the webdataset shards."""
    json_feature = {
        "aesthetic_predictor_v_2_5_score": Value("double"),
        "enhanced_length": Value("int64"),
        "enhanced_prompt": Value("string"),
        "enhancer": Value("string"),
        "id": Value("string"),
        "image_aspect_ratio": Value("string"),
        "image_generated_with_enhanced_prompt": Value("bool"),
        "image_generator": Value("string"),
        "image_resolution": Sequence(Value("int64")),
        "length": Value("int64"),
        "prompt": Value("string"),
        "prompt_category": Value("string"),
        "prompt_generator": Value("string"),
        "style": Value("string"),
        "task": Sequence(Value("string")),  # corresponds to list<item: string>
    }
    return Features({
        "__key__": Value("string"),
        "__url__": Value("string"),
        "jpg": Image(),
        "txt": Value("string"),
        # Nested schema; this is the field the original list<item: string>
        # error referred to. If the json payload turns out to be free-form,
        # switch this to Value("string") and json.loads it manually afterwards.
        "json": json_feature,
    })


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render index.html with up to 100 shuffled samples from the selected folder.

    The folder is read from the ``folder`` form field. For GET and POST
    requests the matching shards are streamed, shuffled through a buffer, and
    each sample's image, caption, metadata and source tar URL are collected
    for the template. Any other method renders the template with no samples.

    Returns:
        The rendered template, or an error message with status 500 if
        loading or encoding the samples fails.
    """
    samples = []
    filter_folder = request.form.get('folder', 'synthetic_enhanced_prompt_random_resolution')
    print(f"filter_folder: {filter_folder}")
    url_pattern = _resolve_url_pattern(filter_folder)

    if request.method in ('GET', 'POST'):
        features = _build_features()
        try:
            dataset = load_dataset(
                "webdataset",
                data_files={"train": url_pattern},
                split="train",
                features=features,  # force the explicit schema
                streaming=True,  # please do streaming, or you will download the whole dataset ~2TB
            )
            shuffled = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER_SIZE)
            # islice stops after the last wanted sample, so no extra item is
            # pulled from the stream just to discover the limit was reached.
            for count, item in enumerate(islice(shuffled, _MAX_SAMPLES)):
                print(count)
                # Field names must match the keys in the dataset (jpg, json, txt).
                samples.append({
                    "image": image_to_base64(item['jpg']),
                    "text": item.get('txt', ''),
                    "json": item.get('json', {}),
                    "tar_file": item["__url__"],
                })
        except Exception as e:
            # Top-level boundary: report to stdout (shown in the HF Space logs)
            # and return a generic 500 rather than leaking details to the client.
            print(f"Runtime Error: {e}")
            return "Error loading dataset. Please check logs.", 500

    return render_template('index.html', samples=samples)
if __name__ == "__main__":
    # Hugging Face Spaces requires the app to listen on port 7860.
    app.run(port=7860, host="0.0.0.0")