|
|
@@ -6,6 +6,7 @@ import sys
|
|
|
import torch
|
|
|
from datasets import load_dataset
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
+from trl import SFTConfig, SFTTrainer
|
|
|
|
|
|
# Parse command line arguments
|
|
|
parser = argparse.ArgumentParser(
|
|
|
@@ -66,8 +67,6 @@ if use_quantized:
|
|
|
if use_peft:
|
|
|
from peft import LoraConfig
|
|
|
|
|
|
-from trl import setup_chat_format, SFTConfig, SFTTrainer
|
|
|
-
|
|
|
# Dataset configuration based on CoT parameter
|
|
|
if use_cot:
|
|
|
FT_DATASET = "train_text2sql_cot_dataset.json"
|