File size: 4,749 Bytes
77772bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7155406
 
 
 
 
 
 
 
 
 
77772bc
 
7155406
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77772bc
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import io
import base64
from flask import Flask, render_template, request
from datasets import load_dataset, Features, Value, Image, Sequence
from PIL import Image as PILImage  # alias Pillow's Image so it does not clash with datasets.Image

app = Flask(__name__)

# Hugging Face Hub dataset ID to browse; replace with your own dataset ID.
DATASET_ID = "ma-xu/fine-t2i"

def image_to_base64(pil_img):
    """Downscale a PIL image and return it as a JPEG base64 data URI.

    Args:
        pil_img: a PIL.Image.Image. The caller's image is NOT modified
            (the original implementation called ``thumbnail`` in place,
            mutating the argument).

    Returns:
        str: ``data:image/jpeg;base64,...`` suitable for an <img src=...>.
    """
    # Pillow >= 9.1 moved resampling filters to Image.Resampling;
    # fall back to the legacy module-level constant on older versions.
    try:
        resampling_mode = PILImage.Resampling.LANCZOS
    except AttributeError:
        resampling_mode = PILImage.LANCZOS

    # Work on a copy: thumbnail() resizes in place and would otherwise
    # clobber the caller's image.
    img = pil_img.copy()

    # 1. Cap the longest side at 800 px (aspect ratio preserved) so pages load fast.
    img.thumbnail((800, 800), resampling_mode)

    # 2. JPEG only supports L / RGB / CMYK; convert anything else
    #    (RGBA, P, LA, "1", ...) to RGB to avoid a save-time error.
    #    The original only handled RGBA and P.
    if img.mode not in ("L", "RGB", "CMYK"):
        img = img.convert("RGB")

    # 3. Encode to base64 and wrap as a data URI.
    byte_arr = io.BytesIO()
    img.save(byte_arr, format='JPEG', quality=85)
    encoded_img = base64.b64encode(byte_arr.getvalue()).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded_img}"

@app.route('/', methods=['GET', 'POST'])
def index():
    """Stream up to 100 shuffled samples from the selected dataset subset.

    GET renders the default subset; POST re-renders with the subset chosen
    in the ``folder`` form field. Returns the rendered ``index.html`` on
    success, or a 500 with a short message if the dataset cannot be loaded.
    """
    samples = []

    base = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main"
    # Map each form-selectable subset to its webdataset shard URL pattern.
    url_patterns = {
        "curated": f"{base}/curated/train-*.tar",
        "enhanced_prompt_random_resolution": f"{base}/synthetic_enhanced_prompt_random_resolution/train-*.tar",
        "enhanced_prompt_square_resolution": f"{base}/synthetic_enhanced_prompt_square_resolution/train-*.tar",
        "original_prompt_random_resolution": f"{base}/synthetic_original_prompt_random_resolution/train-*.tar",
        "original_prompt_square_resolution": f"{base}/synthetic_original_prompt_square_resolution/train-*.tar",
    }

    # BUG FIX: the previous default, 'synthetic_enhanced_prompt_random_resolution',
    # matched no branch (branches compare names WITHOUT the 'synthetic_' prefix),
    # so the intended default silently fell through to the square-resolution URL.
    filter_folder = request.form.get('folder', 'enhanced_prompt_random_resolution')
    print(f"filter_folder: {filter_folder}")
    url_pattern = url_patterns.get(
        filter_folder, url_patterns["original_prompt_square_resolution"]
    )

    # Explicit schema for the per-sample JSON sidecar files in the tars.
    json_feature = {
        "aesthetic_predictor_v_2_5_score": Value("double"),
        "enhanced_length": Value("int64"),
        "enhanced_prompt": Value("string"),
        "enhancer": Value("string"),
        "id": Value("string"),
        "image_aspect_ratio": Value("string"),
        "image_generated_with_enhanced_prompt": Value("bool"),
        "image_generator": Value("string"),
        "image_resolution": Sequence(Value("int64")),
        "length": Value("int64"),
        "prompt": Value("string"),
        "prompt_category": Value("string"),
        "prompt_generator": Value("string"),
        "style": Value("string"),
        "task": Sequence(Value("string")),  # corresponds to list<item: string>
    }
    features = Features({
        "__key__": Value("string"),
        "__url__": Value("string"),
        "jpg": Image(),
        "txt": Value("string"),
        "json": json_feature,
    })

    try:
        dataset = load_dataset(
            "webdataset",
            data_files={"train": url_pattern},
            split="train",
            features=features,  # force the schema; avoids type-inference errors
            streaming=True,  # streaming is mandatory: the full dataset is ~2 TB
        )
        ds = dataset.shuffle(buffer_size=100)

        for item in ds:
            if len(samples) >= 100:
                break
            print(len(samples))
            # Field names match the keys inside each tar member (jpg/json/txt).
            samples.append({
                "image": image_to_base64(item['jpg']),
                "text": item.get('txt', ''),
                "json": item.get('json', {}),
                "tar_file": item["__url__"],
            })
    except Exception as e:
        print(f"Runtime Error: {e}")  # shows up in the HF Space logs
        return "Error loading dataset. Please check logs.", 500

    return render_template('index.html', samples=samples)

if __name__ == '__main__':
    # Hugging Face Spaces routes traffic to port 7860; bind all interfaces.
    app.run(port=7860, host='0.0.0.0')