Tree Prompting: Efficient Task Adaptation without Fine-Tuning. Code for the Tree-Prompt paper.
Tree Prompting uses training examples to learn a tree of prompts for making a classification, yielding higher accuracy and better efficiency than baseline ensembles.
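To picture the mechanism: each internal node of the learned decision tree holds a prompt, the LLM's verbalized answer to that prompt routes an example down one branch or the other, and the leaves assign class labels. The toy sketch below illustrates this routing with a hypothetical hand-built tree and a stand-in `llm_says_positive` helper (a keyword heuristic instead of a real LLM call); it is not the repo's API, just a picture of the idea.

```python
# Toy illustration of tree-of-prompts inference (hypothetical helper and
# hand-built tree; the real training/inference code lives in this repo and
# in imodelsx). Each internal node holds a prompt; the LLM's verbalized
# answer to that prompt routes the example until a leaf gives the class.

def llm_says_positive(text: str, prompt: str) -> bool:
    """Stand-in for an LLM call: in the real method, the model scores the
    verbalizer tokens (e.g. ' Positive.' vs ' Negative.') as continuations
    of prompt + text. A trivial keyword check keeps this sketch runnable."""
    return any(w in text.lower() for w in ("good", "great", "fun"))

# hand-built example tree: internal nodes are
# (prompt, branch_if_negative, branch_if_positive); leaves are class labels
tree = (
    "This movie is",
    ("The acting in the movie was", 0, 1),  # follow-up prompt if the root says negative
    1,                                      # predict positive if the root says positive
)

def predict(text, node):
    if not isinstance(node, tuple):  # leaf reached: return its class label
        return node
    prompt, if_negative, if_positive = node
    branch = if_positive if llm_says_positive(text, prompt) else if_negative
    return predict(text, branch)

print(predict("A great, heartfelt film.", tree))      # -> 1
print(predict("Dull plot and wooden acting.", tree))  # -> 0
```

In the actual method, the tree structure and which prompt is used at each node are learned from training data, as in the imodelsx example below.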
To use Tree-Prompt through a simple scikit-learn interface, install the imodelsX package: `pip install imodelsx`. The example below fits a `TreePromptClassifier` on a small subset of the rotten_tomatoes dataset and inspects the learned tree.
```python
from imodelsx import TreePromptClassifier
import datasets
import numpy as np
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# set up data
rng = np.random.default_rng(seed=42)
dset_train = datasets.load_dataset('rotten_tomatoes')['train']
dset_train = dset_train.select(rng.choice(
    len(dset_train), size=100, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(rng.choice(
    len(dset_val), size=100, replace=False))

# set up arguments
prompts = [
    "This movie is",
    " Positive or Negative? The movie was",
    " The sentiment of the movie was",
    " The plot of the movie was really",
    " The acting in the movie was",
]
verbalizer = {0: " Negative.", 1: " Positive."}
checkpoint = "gpt2"

# fit model
m = TreePromptClassifier(
    checkpoint=checkpoint,
    prompts=prompts,
    verbalizer=verbalizer,
    cache_prompt_features_dir=None,  # 'cache_prompt_features_dir/gp2',
)
m.fit(dset_train["text"], dset_train["label"])

# compute accuracy
preds = m.predict(dset_val['text'])
print('\nTree-Prompt acc (val) ->',
      np.mean(preds == dset_val['label']))  # -> 0.7

# compare to accuracy for individual prompts
for i, prompt in enumerate(prompts):
    print(i, prompt, '->', m.prompt_accs_[i])  # -> 0.65, 0.5, 0.5, 0.56, 0.51

# visualize decision tree
plot_tree(
    m.clf_,
    fontsize=10,
    feature_names=m.feature_names_,
    class_names=list(verbalizer.values()),
    filled=True,
)
plt.show()
```
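As an optional extra (standard scikit-learn API, using only the attributes already shown above; assumes `m` has been fit as in the example), the learned tree can also be printed as text, which can be easier to scan than the plot for deeper trees:

```python
from sklearn.tree import export_text

# textual view of the same fitted tree: one line per split/leaf,
# with prompts as the feature names (assumes `m` was fit as above)
print(export_text(m.clf_, feature_names=list(m.feature_names_)))
```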
Reference:

```bibtex
@misc{ch2022augmenting,
    title={Tree Prompting: Efficient Task Adaptation without Fine-Tuning},
    year={2023},
    archivePrefix={arXiv},
    primaryClass={cs.AI}
}
```
Repository organization:
- `tprompt`: contains main code for modeling (e.g. model architecture)
- `experiments`: code for running experiments (e.g. loading data, training models, evaluating models)
- `scripts`: scripts for running experiments (e.g. python scripts that launch jobs in the `experiments` folder with different hyperparams)
- `notebooks`: jupyter notebooks for analyzing results and making figures
- `tests`: unit tests
Setup:
- clone and run `pip install -e .`, resulting in a package named `tprompt` that can be imported
- see `setup.py` for dependencies; not all are required
- example run: run `python scripts/01_train_basic_models.py` (which calls `experiments/01_train_model.py`), then view the results in `notebooks/01_model_results.ipynb`