import transformers
import torch
import pandas as pd
from tqdm import tqdm
from huggingface_hub import login

# Authenticate with the Hugging Face Hub.
# NOTE: `huggingface_hub.login` takes a user access token, not a username;
# replace the placeholder below with your own token before running.
login("your_login_ID")

# Prompting approaches to run; each one selects the `prompt_<mode>` column
# of the input CSV.
modes = ["zs", "desc", "fs"]

# Hugging Face model ID for each LLM short name.
model_ids = dict(
    llama="meta-llama/Llama-3.2-3B-Instruct",
    mistral="mistralai/Mistral-Nemo-Instruct-2407",
)

# LLM short names, in the order they are processed (the keys of `model_ids`).
llm_names = list(model_ids)

# Seed PyTorch's random number generator with a fixed value.
torch.manual_seed(20241114)

# Load the input table; it is read for a `doc_id` column and one
# `prompt_<mode>` column per entry in `modes`.
df = pd.read_csv("data_llm_input.csv")

def _stop_token_ids(tokenizer):
    """Return the token IDs that should terminate generation for *tokenizer*.

    The tokenizer's own EOS token is included when it is defined. The
    ``<|eot_id|>`` end-of-turn marker is added only when the tokenizer
    actually knows that token: for an unknown token ``convert_tokens_to_ids``
    returns the unknown-token ID (or ``None``), and neither may be used as a
    terminator.

    Returns ``None`` when no terminator can be determined, so that the
    model's default generation settings apply.
    """
    stop_ids = []
    if tokenizer.eos_token_id is not None:
        stop_ids.append(tokenizer.eos_token_id)

    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    eot_is_known = eot_id is not None and eot_id != tokenizer.unk_token_id
    if eot_is_known and eot_id not in stop_ids:
        stop_ids.append(eot_id)

    return stop_ids or None


# Run every LLM over every prompting mode. The model is the outer loop so that
# each set of weights is loaded once and reused for all modes.
for llm_name in llm_names:
    model_id = model_ids[llm_name]  # Model ID for the current LLM

    # Drop the reference to the previously loaded model (if any) before
    # loading the next one, so both are not kept alive at the same time.
    pipeline = None

    # Create pipeline
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )

    # Stop tokens depend only on the tokenizer, so compute them once per model.
    terminators = _stop_token_ids(pipeline.tokenizer)

    for mode in modes:
        prompt_col = f'prompt_{mode}'
        output_col = f'{llm_name}_{mode}_2_raw'

        # Collect result rows in a list and build the DataFrame once at the
        # end; growing a DataFrame with pd.concat per row is quadratic.
        result_rows = []

        # Loop through rows in the input data
        for uniqid, prompt in tqdm(zip(df['doc_id'], df[prompt_col]), total=len(df)):
            messages = [
                {"role": "user", "content": prompt},
            ]

            # Generate output with greedy decoding. No temperature is passed:
            # it only applies when sampling and is ignored for do_sample=False.
            outputs = pipeline(
                messages,
                eos_token_id=terminators,
                do_sample=False,
                max_new_tokens=600,
            )
            # The pipeline returns the chat with the model's reply appended as
            # the last message.
            response = outputs[0]["generated_text"][-1]['content']

            # Print output
            print(f"{mode}-ID: ", uniqid)
            print("Input: ", prompt)
            print("Output: ", response)
            print("")

            result_rows.append({'doc_id': uniqid, prompt_col: prompt, output_col: response})

        # Passing the columns explicitly keeps the header (and column order)
        # even when the input has no rows.
        new_df = pd.DataFrame(result_rows, columns=['doc_id', prompt_col, output_col])

        # Save results to output CSV file
        output_file = f'completions_{llm_name}_{mode}_2.csv'
        new_df.to_csv(output_file, index=False)
        print(f"Results saved to {output_file}")