
Story-Generation-Demo

A simple story generation demo that fine-tunes the Hugging Face pretrained model BART to generate stories.

Instructions

This project is based on the pytorch-lightning framework, and the pretrained BART model is downloaded from Hugging Face: bart-base.

To run this code, you will need the preliminaries covered in the Quick Start below.

Quick Start

1. Install packages

pip install -r requirements.txt

2. Collect Datasets and Resources

The datasets and resources are kept separate from the code, since they are too large. Both can be downloaded from BaiduNetDisk (input code: gb1a) or Dropbox. Put them in the base directory after downloading.

2.1 Datasets

The structure of datasets should be like this:

├── datasets
   └── story-generation		# experiment group name
       ├── `roc`        # ROCStories
              └── `train.source.txt`    # leading context of stories
              └── `train.target.txt`       # story corresponding to the leading context
              └── `val.source.txt` 
              └── `val.target.txt` 
              └── `test.source.txt` 
              └── `test.target.txt` 

The raw ROCStories dataset can be accessed for free; search for it online, e.g. its homepage .

The train, val, and test sets are split with a ratio of 0.90 / 0.05 / 0.05.
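For reference, here is a minimal sketch of how such a split could be produced, assuming each story is already available as a list of sentences (the preprocessing code is not part of this repo, so the function name and details below are illustrative): the first sentence becomes the leading context and the remaining sentences become the target story.

```python
import os
import random

def split_roc(stories, out_dir="datasets/story-generation/roc", seed=42):
    """Split ROCStories into 0.90 / 0.05 / 0.05 source/target files.

    `stories` is a list of stories, each a list of sentences: the first
    sentence becomes the leading context (*.source.txt) and the remaining
    sentences become the story (*.target.txt), one example per line.
    """
    os.makedirs(out_dir, exist_ok=True)
    random.Random(seed).shuffle(stories)
    n = len(stories)
    n_train, n_val = int(0.90 * n), int(0.05 * n)
    splits = {
        "train": stories[:n_train],
        "val": stories[n_train:n_train + n_val],
        "test": stories[n_train + n_val:],
    }
    for name, subset in splits.items():
        with open(f"{out_dir}/{name}.source.txt", "w") as src, \
             open(f"{out_dir}/{name}.target.txt", "w") as tgt:
            for sents in subset:
                src.write(sents[0].strip() + "\n")
                tgt.write(" ".join(s.strip() for s in sents[1:]) + "\n")
```

Reading the raw ROCStories download into sentence lists is left out here, since the raw file format may differ depending on where you obtain it.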

An example line from test.source.txt (the leading context):

ken was driving around in the snow .

An example line from test.target.txt (the corresponding story):

he needed to get home from work . he was driving slowly to avoid accidents . unfortunately the roads were too slick and ken lost control . his tires lost traction and he hit a tree .
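Because the source and target files are line-aligned, loading them for seq2seq fine-tuning is straightforward. Below is a minimal sketch of such a dataset class; it illustrates the data format only and is not the dataset class used by this repo (the class name and paths are assumptions).

```python
from pathlib import Path

from torch.utils.data import Dataset
from transformers import BartTokenizer

class LeadingContextDataset(Dataset):
    """Pairs each leading-context line with its story line."""

    def __init__(self, data_dir, split, tokenizer, max_len=128):
        data_dir = Path(data_dir)
        self.sources = (data_dir / f"{split}.source.txt").read_text().splitlines()
        self.targets = (data_dir / f"{split}.target.txt").read_text().splitlines()
        assert len(self.sources) == len(self.targets)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.sources)

    def __getitem__(self, idx):
        src = self.tokenizer(self.sources[idx], truncation=True,
                             max_length=self.max_len, return_tensors="pt")
        tgt = self.tokenizer(self.targets[idx], truncation=True,
                             max_length=self.max_len, return_tensors="pt")
        return {
            "input_ids": src["input_ids"].squeeze(0),
            "attention_mask": src["attention_mask"].squeeze(0),
            "labels": tgt["input_ids"].squeeze(0),
        }

# usage (assumes the datasets and the downloaded bart-base are in place):
# tok = BartTokenizer.from_pretrained("resources/external-generation/bart-base")
# ds = LeadingContextDataset("datasets/story-generation/roc", "train", tok)
```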

2.2 Resources

The structure of resources should be like this:

├── resources
   └── external-generation		
       ├── `bart-base`        
              └── `config.json`    
              └── `pytorch_model.bin`       
              └── ...

The Hugging Face pretrained model bart-base can be downloaded from here.
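If you prefer to fetch it programmatically instead of using the link above, a minimal sketch with transformers looks like this (the hub id facebook/bart-base is the public bart-base checkpoint, and the target path follows the layout shown above):

```python
from transformers import BartForConditionalGeneration, BartTokenizer

target_dir = "resources/external-generation/bart-base"

# Download from the Hugging Face hub and save locally, so the training
# script can load the model from disk instead of the hub.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer.save_pretrained(target_dir)
model.save_pretrained(target_dir)
```

Note that the example training command below points at resources/external_models/bart-base, so adjust the path to whichever layout you actually use.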

3. Fine-tuning BART on ROCStories

I have set all essential parameters, so you can directly run

python ./tasks/story-generation/train.py

Or

If you want to modify parameters, you can run

python tasks/story-generation/train.py --data_dir=datasets/story-generation/roc-stories \
  --learning_rate=5e-5 \
  --train_batch_size=16 \
  --eval_batch_size=10 \
  --model_name_or_path=resources/external_models/bart-base \
  --output_dir=output/story-generation \
  --model_name leading-bart \
  --experiment_name=leading-bart-roc-stories \
  --val_check_interval=1.0 \
  --limit_val_batches=10 \
  --max_epochs=3 \
  --accum_batches_args=4
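For orientation, most of these flags correspond to standard pytorch-lightning Trainer arguments. The sketch below shows that rough mapping only; the actual argument parsing lives in tasks/story-generation/train.py, and the interpretation of --accum_batches_args as gradient accumulation is my assumption.

```python
import pytorch_lightning as pl

# Rough, illustrative mapping of the flags above onto a pytorch-lightning Trainer.
trainer = pl.Trainer(
    max_epochs=3,               # --max_epochs
    val_check_interval=1.0,     # --val_check_interval (validate once per epoch)
    limit_val_batches=10,       # --limit_val_batches
    accumulate_grad_batches=4,  # --accum_batches_args (assumed: gradient accumulation)
)
```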

4. Generating Stories and Evaluation

Same as for training. Directly run

python ./tasks/story-generation/test.py

Or

python tasks/story-generation/test.py --data_dir=datasets/story-generation/roc-stories \
  --eval_batch_size=10 \
  --model_name_or_path=output/story-generation/leading-bart-roc-stories/best_tfmr \
  --output_dir=output/story-generation \
  --model_name leading-bart \
  --experiment_name=leading-bart-roc-stories
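The test command above loads the fine-tuned model from output/story-generation/leading-bart-roc-stories/best_tfmr, so that checkpoint can also be loaded directly with transformers for ad-hoc generation. A minimal sketch, assuming the checkpoint is saved in the standard Hugging Face format (the generation hyperparameters here are illustrative, not the ones used by test.py):

```python
from transformers import BartForConditionalGeneration, BartTokenizer

ckpt = "output/story-generation/leading-bart-roc-stories/best_tfmr"
tokenizer = BartTokenizer.from_pretrained(ckpt)
model = BartForConditionalGeneration.from_pretrained(ckpt)

# Generate a story continuation for a single leading context.
leading_context = "ken was driving around in the snow ."
inputs = tokenizer(leading_context, return_tensors="pt")
story_ids = model.generate(**inputs, max_length=80, num_beams=4, early_stopping=True)
print(tokenizer.decode(story_ids[0], skip_special_tokens=True))
```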

Notes

Some notes on this project.

1 - Complete Project Structure

├── datasets
├── output  # created automatically; holds all output, including checkpoints and generated text
├── resources # resources used by the model, e.g. the pretrained model
├── tasks # executable programs, e.g. training, testing, generating stories
├── .gitignore # used by git
├── requirements.txt # the checklist of essential python packages

2 - Scripts for Downloading Hugging Face Models

I wrote two scripts to download models from the Hugging Face website: one is tasks/download_hf_models.sh, and the other is src/utils/huggingface_helper.py.
