import glob
import os
import re

import autopep8
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


def normalize_indentation(code):
    """
    Normalize indentation in example code by removing excessive tabs.
    Also removes any backslash characters.
    """
    code = code.replace("\\", "")

    lines = code.split("\n")
    if not lines:
        return ""

    fixed_lines = []
    indent_fix_mode = False

    for line in lines:
        if line.strip().startswith("def "):
            fixed_lines.append(line)
            indent_fix_mode = True
        elif indent_fix_mode and line.strip():
            # Collapse one level of excess indentation: two tabs (or eight
            # spaces) down to one tab (or four spaces).
            if line.startswith("\t\t"):
                fixed_lines.append("\t" + line[2:])
            elif line.startswith("        "):
                fixed_lines.append("    " + line[8:])
            else:
                fixed_lines.append(line)
        else:
            fixed_lines.append(line)

    return "\n".join(fixed_lines)

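# A minimal sketch of the collapse performed above (hypothetical input, not
# taken from the bundled examples):
#
#   normalize_indentation("def f():\n\t\treturn 1")
#   -> "def f():\n\treturn 1"
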
def clear_text(text):
    """
    Cleans text from escape sequences while preserving original formatting.
    """
    temp_newline = "TEMP_NEWLINE_PLACEHOLDER"
    temp_tab = "TEMP_TAB_PLACEHOLDER"

    # Protect the two-character sequences \n and \t before stripping
    # stray backslashes.
    text = text.replace("\\n", temp_newline)
    text = text.replace("\\t", temp_tab)

    text = text.replace("\\", "")

    # Restore the protected sequences as real control characters.
    text = text.replace(temp_newline, "\n")
    text = text.replace(temp_tab, "\t")

    return text


def encode_text(text):
    """
    Encodes control characters into escape sequences.
    """
    text = text.replace("\n", "\\n")
    text = text.replace("\t", "\\t")
    return text

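# encode_text and clear_text are intended to be inverses; a quick sketch
# (hypothetical input):
#
#   encode_text("a\n\tb")              -> "a\\n\\tb"
#   clear_text(encode_text("a\n\tb"))  -> "a\n\tb"
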
def format_code(code):
    """
    Format Python code using autopep8 with aggressive settings.
    """
    try:
        formatted_code = autopep8.fix_code(
            code,
            options={
                "aggressive": 2,
                "max_line_length": 88,
                "indent_size": 4,
            },
        )

        # Remove padding just inside parentheses.
        formatted_code = formatted_code.replace("( ", "(").replace(" )", ")")

        # Collapse double spaces around operators to a single space.
        for op in ["+", "-", "*", "/", "=", "==", "!=", ">=", "<=", ">", "<"]:
            formatted_code = formatted_code.replace(f"{op}  ", f"{op} ")
            formatted_code = formatted_code.replace(f"  {op}", f" {op}")

        # Remove any space between a name and its opening parenthesis.
        formatted_code = re.sub(r"(\w+)\s+\(", r"\1(", formatted_code)

        return formatted_code
    except Exception as e:
        print(f"Error formatting code: {str(e)}")
        return code

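# A rough sketch of the full cleanup pass (hypothetical input; the exact
# output depends on the installed autopep8 version):
#
#   format_code("x=1\nprint ( x )")
#   -> "x = 1\nprint(x)\n"
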
def fix_common_syntax_issues(code):
    """
    Fix common syntax issues in generated code without modifying indentation.
    """
    lines = code.split("\n")
    fixed_lines = []

    for line in lines:
        stripped = line.strip()
        if (
            stripped.startswith("if ")
            or stripped.startswith("elif ")
            or stripped.startswith("else")
            or stripped.startswith("for ")
            or stripped.startswith("while ")
            or stripped.startswith("def ")
            or stripped.startswith("class ")
        ):
            # Add a missing colon at the end of a block header, unless the
            # line ends with an explicit continuation.
            if not stripped.endswith(":") and not stripped.endswith("\\"):
                line = line.rstrip() + ":"

        fixed_lines.append(line)

    code = "\n".join(fixed_lines)

    # Close an unbalanced quote on the first offending line.
    quote_chars = ['"', "'"]
    for quote in quote_chars:
        if code.count(quote) % 2 != 0:
            lines = code.split("\n")
            for i, line in enumerate(lines):
                if line.count(quote) % 2 != 0:
                    lines[i] = line.rstrip() + quote
                    break
            code = "\n".join(lines)

    # Close a call whose opening parenthesis is never matched on the same
    # line and is not closed within the next couple of lines.
    pattern = r"(\w+)\s*\([^)]*$"
    if re.search(pattern, code):
        lines = code.split("\n")
        for i, line in enumerate(lines):
            if re.search(pattern, line) and not any(
                lines[j].strip().startswith(")")
                for j in range(i + 1, min(i + 3, len(lines)))
            ):
                lines[i] = line.rstrip() + ")"
        code = "\n".join(lines)

    return code

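# A quick sketch (hypothetical generated snippet with a missing colon and an
# unterminated string literal):
#
#   fix_common_syntax_issues("def f()\n    s = 'abc")
#   -> "def f():\n    s = 'abc'"
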
def load_example_from_file(example_path):
    """
    Load example from a file with format:
    description_BREAK_code
    where 'code' uses \\n and \\t for formatting.
    """
    try:
        with open(example_path, "r") as f:
            content = f.read()

        parts = content.split("_BREAK_")
        if len(parts) == 2:
            description = parts[0].strip()
            code = parts[1].strip()

            # Turn the escaped control characters back into real ones.
            code = code.replace("\\n", "\n").replace("\\t", "\t")
            code = normalize_indentation(code)

            return description, code
        else:
            print(f"Invalid format in example file: {example_path}")
            return "", ""
    except Exception as e:
        print(f"Error loading example file {example_path}: {str(e)}")
        return "", ""

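# A sketch of the expected raw.in layout (hypothetical content; the real
# files live under examples/*/raw.in):
#
#   Swap the comparison operator in the guard _BREAK_ def f(x):\n\tif x > 0:\n\t\treturn x
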
def find_example_files():
    """
    Find all raw.in example files in the examples directory.
    """
    return glob.glob("examples/*/raw.in")


# Model setup: base CodeT5+ checkpoint plus fine-tuned weights.
BASE_MODEL_ID = "Salesforce/codet5p-770m"
FINETUNED_REPO_ID = "OSS-forge/codet5p-770m-pyresbugs"
FINETUNED_FILENAME = "pytorch_model.bin"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Loading tokenizer from base model: {BASE_MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

print(f"Loading base model: {BASE_MODEL_ID}")
model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)
model.to(device)

print(f"Downloading fine-tuned weights from repo: {FINETUNED_REPO_ID}")
ckpt_path = hf_hub_download(FINETUNED_REPO_ID, FINETUNED_FILENAME)

print(f"Loading state_dict from: {ckpt_path}")
state_dict = torch.load(ckpt_path, map_location="cpu")

# Some checkpoints wrap the weights in a training-loop dictionary.
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]

missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(
    f"Loaded fine-tuned weights. "
    f"Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}"
)

model.eval()


# Conversation state shared across Gradio callbacks.
current_code = None
bug_counter = 0

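# The fine-tuned model consumes a single string combining the description and
# the escaped code, mirroring what generate_bugged_code builds below:
#
#   "Description: <bug description> _BREAK_ Code: <code with \n and \t escaped>"
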
def generate_bugged_code(description, code, chat_history, is_first_time):
    global current_code, bug_counter

    if chat_history is None:
        chat_history = []

    if is_first_time:
        bug_counter = 0
        current_code = None
        chat_history = []

    bug_counter += 1

    # The first round mutates the user's code; later rounds mutate the
    # previously generated bugged code.
    if bug_counter == 1:
        input_for_model = code
        input_type = "original"
    else:
        if current_code is None:
            return chat_history, gr.update(value=""), False
        input_for_model = current_code
        input_type = "previous bugged code"

    print(f"Using {input_type} - counter: {bug_counter}\n{input_for_model}")

    encoded_code = encode_text(input_for_model)
    combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"

    inputs = tokenizer(
        combined_input,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).input_ids.to(device)

    try:
        print("Starting generation...")
        with torch.no_grad():
            # Greedy decoding; early_stopping is omitted since it only
            # applies to beam search.
            outputs = model.generate(
                inputs,
                max_new_tokens=256,
                num_beams=1,
                do_sample=False,
            )
        print("Generation done.")
    except Exception as e:
        print("Generation error:", repr(e))
        raise

    bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Post-process: unescape, repair, and reformat the generated code.
    bugged_code = clear_text(bugged_code_escaped)
    bugged_code = fix_common_syntax_issues(bugged_code)
    bugged_code = format_code(bugged_code)

    current_code = bugged_code

    user_message = f"**Description**: {description}"
    if input_type == "original":
        user_message += f"\n\n**Original code**:\n```python\n{input_for_model}\n```"
    else:
        user_message += (
            f"\n\n**Previous bugged code**:\n```python\n{input_for_model}\n```"
        )

    ai_message = f"**Bugged code**:\n```python\n{bugged_code}\n```"

    chat_history = chat_history + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ai_message},
    ]

    return chat_history, gr.update(value=""), False

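# A rough sketch of driving the pipeline without the UI (hypothetical values;
# passing is_first_time=True resets the conversation state first):
#
#   history, _, _ = generate_bugged_code(
#       "Introduce an off-by-one error",
#       "def first(xs):\n    return xs[0]",
#       [],
#       True,
#   )
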
def reset_interface():
    global current_code, bug_counter
    current_code = None
    bug_counter = 0
    return [], gr.update(value=""), True


example_files = find_example_files()
example_names = [
    f"Example {i + 1}: {os.path.basename(os.path.dirname(f))}"
    for i, f in enumerate(example_files)
]

def load_example(example_index):
    if example_index < len(example_files):
        return load_example_from_file(example_files[example_index])
    return "", ""

with gr.Blocks(title="Software-Fault Injection from NL") as demo:
    gr.Markdown("# 🐞 Software-Fault Injection from Natural Language")
    gr.Markdown(
        "Generate Python code with specific bugs based on a description and original code. "
        "The model used is **BugGen (CodeT5+ 770M, PyResBugs)**."
    )

    with gr.Row():
        with gr.Column(scale=2):
            description_input = gr.Textbox(
                label="Bug Description",
                placeholder="Describe the type of bug to introduce...",
                lines=3,
            )
            code_input = gr.Code(
                label="Original Code",
                language="python",
                lines=12,
            )

            is_first = gr.State(True)

            submit_btn = gr.Button("Generate Bugged Code")
            reset_btn = gr.Button("Start Over")

            gr.Markdown("### Examples")
            example_buttons = [gr.Button(name) for name in example_names]

        with gr.Column(scale=3):
            # type="messages" matches the {"role": ..., "content": ...}
            # dicts produced by generate_bugged_code.
            chat_output = gr.Chatbot(
                label="Conversation",
                type="messages",
                height=500,
            )

    # Bind each example button to its own index via a default argument.
    for i, btn in enumerate(example_buttons):
        btn.click(
            fn=lambda i=i: load_example(i),
            outputs=[description_input, code_input],
        )

    submit_btn.click(
        fn=generate_bugged_code,
        inputs=[description_input, code_input, chat_output, is_first],
        outputs=[chat_output, description_input, is_first],
    )

    reset_btn.click(
        fn=reset_interface,
        outputs=[chat_output, description_input, is_first],
    )

print("Launching Gradio interface...")
demo.queue(max_size=10).launch()