import enum
from pathlib import Path
from typing import Any, Literal

import torch
from diffusers.image_processor import PipelineImageInput
from pydantic import BaseModel, ConfigDict, Field

from qwenimage.types import DataRange
from wandml.foundation.datamodels import FluxInputs
from wandml.trainers.datamodels import ExperimentTrainerParameters


class QwenInputs(BaseModel):
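    """Call arguments for a Qwen image generation pipeline.

    Mirrors the diffusers pipeline signature (image, prompt, sizes, CFG
    scale, step count, generator); the `*_override` fields are
    project-specific size overrides.
    """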
    image: PipelineImageInput | None = None
    prompt: str | list[str] | None = None
    height: int | None = None
    width: int | None = None
    negative_prompt: str | list[str] | None = None
    true_cfg_scale: float = 1.0
    num_inference_steps: int = 50
    generator: torch.Generator | list[torch.Generator] | None = None
    max_sequence_length: int = 512
    vae_image_override: int | None = None
    latent_size_override: int | None = None

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        # extra="allow",
    )

class TrainingType(str, enum.Enum):
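    """Supported training regimes; NAIVE and IM2IM are the style modes."""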
    IM2IM = "im2im"
    NAIVE = "naive"
    REGRESSION = "regression"

    @property
    def is_style(self) -> bool:
        return self in (TrainingType.NAIVE, TrainingType.IM2IM)

class QuantOptions(str, enum.Enum):
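    """Weight quantization schemes: int8/int4 weight-only and fp8 row-wise."""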
    INT8WO = "int8wo"
    INT4WO = "int4wo"
    FP8ROW = "fp8row"

LossTermSpecType = int | float | dict[str, int | float] | None

class QwenLossTerms(BaseModel):
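    """Weights for each training/validation loss term (0.0 disables a term).

    Dict-valued weights allow non-constant schedules; their exact semantics
    are defined by the trainer that consumes this config.
    """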
    mse: LossTermSpecType = 1.0
    triplet: LossTermSpecType = 0.0
    negative_mse: LossTermSpecType = 0.0
    distribution_matching: LossTermSpecType = 0.0
    pixel_triplet: LossTermSpecType = 0.0
    pixel_lpips: LossTermSpecType = 0.0
    pixel_mse: LossTermSpecType = 0.0
    pixel_distribution_matching: LossTermSpecType = 0.0
    adversarial: LossTermSpecType = 0.0
    teacher: LossTermSpecType = 0.0

    triplet_margin: float = 0.0
    triplet_min_abs_diff: float = 0.0
    teacher_steps: int = 4

    @property
    def pixel_terms(self) -> tuple[str, ...]:
        """Names of the loss terms that operate in pixel space."""
        return ("pixel_lpips", "pixel_mse", "pixel_triplet", "pixel_distribution_matching")

class QwenConfig(ExperimentTrainerParameters):
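    """Trainer configuration for Qwen image experiments.

    Extends `ExperimentTrainerParameters` with noise-schedule, quantization,
    loss-term, and dataset/style options.
    """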
    load_multi_view_lora: bool = False
    train_max_sequence_length: int = 512
    train_dist: str = "linear"  # alternative: "logit-normal"
    train_shift: bool = True
    inference_dist: str = "linear"
    inference_shift: bool = True
    static_mu: float | None = None
    loss_weight_dist: str | None = None  # e.g. "scaled_clipped_gaussian", "logit-normal"

    vae_image_size: int = 1024 * 1024
    offload_text_encoder: bool = True
    quantize_text_encoder: bool = False
    quantize_transformer: bool = False
    vae_tiling: bool = False

    train_loss_terms: QwenLossTerms = Field(default_factory=QwenLossTerms)
    validation_loss_terms: QwenLossTerms = Field(default_factory=QwenLossTerms)

    training_type: TrainingType | None = None
    train_range: DataRange | None = None
    val_range: DataRange | None = None
    test_range: DataRange | None = None

    style_title: str | None = None
    style_base_dir: str | None = None
    style_csv_path: str | None = None
    style_data_dir: str | None = None
    style_ref_dir: str | None = None
    style_val_with: str = "train"
    naive_static_prompt: str | None = None

    regression_data_dir: str | Path | None = None
    regression_gen_steps: int = 50
    editing_data_dir: str | Path | None = None
    editing_total_per: int = 1
    regression_base_pipe_steps: int = 8

    name_suffix: dict[str, Any] | None = None

    def add_suffix_to_names(self) -> None:
        """Append each `name_suffix` key/value pair to `run_name` and `output_dir`."""
        if self.name_suffix is None:
            return
        suffix_sum = "".join(
            f"_{name}_{value}" for name, value in self.name_suffix.items()
        )
        self.run_name += suffix_sum
        self.output_dir = self.output_dir.removesuffix("/")  # tolerate a trailing slash
        self.output_dir += suffix_sum
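
# Usage sketch (hypothetical values; assumes `ExperimentTrainerParameters`
# contributes `run_name: str` and `output_dir: str` with no other required
# fields):
#
#     config = QwenConfig(
#         run_name="qwen_style",
#         output_dir="runs/qwen_style/",
#         training_type=TrainingType.IM2IM,
#         train_loss_terms=QwenLossTerms(mse=1.0, pixel_lpips=0.1),
#         name_suffix={"lr": 1e-4, "steps": 8},
#     )
#     config.add_suffix_to_names()
#     assert config.run_name == "qwen_style_lr_0.0001_steps_8"
#     assert config.output_dir == "runs/qwen_style_lr_0.0001_steps_8"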