import os

import accelerate  # noqa: F401 -- not referenced directly; device_map="auto" requires it installed
import pandas as pd
import torch
import transformers
from huggingface_hub import login
from tqdm import tqdm

# Authenticate with the Hugging Face Hub. login() expects a User Access Token
# (created at https://huggingface.co/settings/tokens), not a username/login ID.
# The token is read from the HF_TOKEN environment variable so it never has to be
# hard-coded (and accidentally committed) here; if HF_TOKEN is unset, login()
# prompts for the token interactively.
login(token=os.environ.get("HF_TOKEN"))

# Run configuration.
# mode: prompting approach; selects the input column f"prompt_{mode}" and is
#       embedded in the output column and file names. One of "zs", "desc", "fs".
mode = "zs"
# model_id: Hub id of the LLM, e.g. "meta-llama/Llama-3.2-3B-Instruct" or
#           "mistralai/Mistral-Nemo-Instruct-2407".
model_id = "mistralai/Mistral-Nemo-Instruct-2407"
# llm_name: short tag used in output names ("llama" or "mistral"); must be kept
#           in sync with model_id by hand.
llm_name = "mistral"
torch.manual_seed(20241114)  # seed torch's RNG for reproducibility

# Fail fast on a typo in `mode`, before the (large) model is downloaded and
# loaded; otherwise the error would only surface as a KeyError on the first row.
_VALID_MODES = ("zs", "desc", "fs")
if mode not in _VALID_MODES:
    raise ValueError(f"mode must be one of {_VALID_MODES}, got {mode!r}")

# Build the text-generation pipeline for the configured model: weights are
# loaded in bfloat16 and device_map="auto" lets accelerate decide where to place
# them. Note: the name `pipeline` is reused by the generation loop below.
pipeline = transformers.pipeline(
    task="text-generation",
    model=model_id,
    model_kwargs=dict(torch_dtype=torch.bfloat16),
    device_map="auto",
)

# Load the test input. Each row must provide a "doc_id" column and the prompt
# column for the selected mode (e.g. "prompt_zs").
df = pd.read_csv("data_input_test.csv")  # for testing

# Empty result table: document ID, the prompt that was sent, and the raw model
# completion, stored in a column named "<llm_name>_<mode>_raw".
output_columns = ["doc_id", f"prompt_{mode}", f"{llm_name}_{mode}_raw"]
new_df = pd.DataFrame(columns=output_columns)

# Generate one completion per input row and collect the results into new_df.

# Stop tokens do not change between rows, so build them once. "<|eot_id|>" is
# an end-of-turn token that only some tokenizers define (e.g. Llama 3's); for
# others (e.g. Mistral's) convert_tokens_to_ids returns the unknown-token id or
# None, which must not be used as a stop token. It is also skipped when it is
# the same id as eos, to avoid a duplicate entry.
tokenizer = pipeline.tokenizer
terminators = [tokenizer.eos_token_id]
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
if eot_id not in (None, tokenizer.unk_token_id, tokenizer.eos_token_id):
    terminators.append(eot_id)

records = []  # one dict per row; the DataFrame is built once after the loop
for _, row in tqdm(df.iterrows(), total=len(df)):
    uniqid = row['doc_id']
    prompt = row[f'prompt_{mode}']
    messages = [
        {"role": "user", "content": prompt},
    ]

    # Greedy decoding. `temperature` is deliberately not passed: it only
    # applies when sampling, and transformers warns when it is combined with
    # do_sample=False.
    outputs = pipeline(
        messages,
        eos_token_id=terminators,
        do_sample=False,
        max_new_tokens=600,
    )
    # With chat-style input, generated_text is the message list with the
    # model's reply appended as the last message.
    response = outputs[0]["generated_text"][-1]['content']
    print(f"{mode}-ID: ", uniqid)
    print("Input: ", prompt)
    print("Output: ", response)
    print("")
    records.append(
        {'doc_id': uniqid, f'prompt_{mode}': prompt, f'{llm_name}_{mode}_raw': response}
    )

# Building the frame once avoids the quadratic cost of a pd.concat per row (and
# pandas' FutureWarning about concatenating onto an empty frame); reusing
# new_df.columns keeps the column order, and yields an empty frame with the same
# columns when the input has no rows.
new_df = pd.DataFrame(records, columns=new_df.columns)

# Persist the collected completions; the file name encodes the model tag and
# the prompting mode. The DataFrame index is not written.
output_path = "completions_{}_{}_test.csv".format(llm_name, mode)
new_df.to_csv(output_path, index=False)
