Madry Lab

TRAK: Attributing Model Behavior at Scale

Effective, efficient data attribution for large machine learning models.

Try it

Data attribution that is fast and effective


TRAK is orders of magnitude faster than comparably effective data attribution methods, and orders of magnitude more effective than comparably fast methods.

Quickstart

Installation


For a fast version with custom CUDA code, use

              pip install traker[fast]
            

You will need the CUDA toolkit and gcc to compile it. For the version that does not require compilation, use

              pip install traker
            

For more details, visit the installation FAQs.
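
To quickly check that either version installed correctly, try importing the main class:

              from trak import TRAKer

If the import succeeds, you are good to go.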

Basic usage


Below we provide a minimal (pseudo)code example showcasing the basic workflow of getting TRAK scores for a given model-dataset pair. For more in-depth examples, including ready-to-run notebooks, check the tutorials in our docs. Those include how to use TRAK with BERT, how to set up TRAK with SLURM, and more!

First, initialize the model and data loaders you want to score with TRAK. For example:

from torchvision import models

model = models.resnet18()
checkpoint = model.state_dict()
train_loader = ImageNetDataloader(train=True, ...)

Then, initialize the TRAKer class and process (featurize) the train set:

from trak import TRAKer

traker = TRAKer(model=model, task='image_classification', train_set_size=...)

traker.load_checkpoint(checkpoint, model_id=0)
for batch in train_loader:
    traker.featurize(batch=batch, num_samples=batch[0].shape[0])
traker.finalize_features()

Finally, get the TRAK scores for your targets, e.g. all ImageNet validation samples:

targets_loader = ImageNetDataloader(train=False, ...)
traker.start_scoring_checkpoint(exp_name='quickstart', checkpoint=checkpoint,
                                model_id=0, num_targets=...)
for batch in targets_loader:
    traker.score(batch=batch, num_samples=batch[0].shape[0])
scores = traker.finalize_scores(exp_name='quickstart')

That's it! Now you have the TRAK scores in a numpy array. Check out the section below for some examples!
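
As a quick sanity check, you can rank training examples by their scores for a given target. A minimal sketch, assuming scores has shape (train set size, number of targets) as returned by finalize_scores, with a hypothetical target_idx:

import numpy as np

# scores[i, j] is the TRAK score of train example i for target j
target_idx = 0  # hypothetical: the first validation target

# train examples with the highest (most positive) scores for this target
top_train = np.argsort(scores[:, target_idx])[::-1][:5]

# train examples with the lowest (most negative) scores for this target
bottom_train = np.argsort(scores[:, target_idx])[:5]

print(top_train, bottom_train)

Galleries like the ones below are produced exactly this way: by looking up the train examples at the top (and bottom) of this ranking.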

Example TRAK scores in language and vision tasks

Microsoft COCO is a dataset containing images of complex everyday scenes described with free-text captions. CLIP is a model that learns visual concepts from natural language supervision by embedding the images and captions in a shared latent space.

Below we show the highest scoring image-caption pairs for a few randomly selected targets from the MS COCO test set. Image-caption pairs from the train set with high TRAK scores have a high influence on CLIP "thinking" that the target image and caption should be close in its latent space.
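
Scoring CLIP follows the same workflow as the quickstart above. A minimal sketch, assuming the 'clip' task shipped with the TRAK repository; clip_model and coco_train_loader are hypothetical placeholders, and CLIP-specific setup (e.g. precomputing embeddings) is covered in the docs:

# Minimal sketch, not the exact tutorial code: clip_model and
# coco_train_loader are hypothetical placeholders.
traker = TRAKer(model=clip_model, task='clip', train_set_size=...)

traker.load_checkpoint(clip_model.state_dict(), model_id=0)
for images, captions in coco_train_loader:
    # each batch is an (images, tokenized captions) pair
    traker.featurize(batch=(images, captions), num_samples=images.shape[0])
traker.finalize_features()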

Target caption: a cat is laying on top of a laptop computer

TRAK top scoring train images (captions shown):
- a cat is laying on top of a laptop computer
- a cat laying next to an open laptop computer
- a dog stretched out laying on a persons legs under a laptop
- the cat is laying on top of the laptop
- a cat is laying on top of a laptop computer
- a cat is laying down on a white laptop

Target caption: a close up of a giraffe and a zebra in a field near trees

TRAK top scoring train images (captions shown):
- a giraffe standing in a field and by trees
- a close up of a giraffe with trees in the background
- two giraffes in a grassy field with small trees
- a giraffe walks in the field with trees and grass
- a giraffe and zebra together in a field
- the giraffe and zebra are outside by the trees

Target caption: a table full of bananas being sold outside

TRAK top scoring train images (captions shown):
- there are many bunches of bananas being sold
- bananas and apples grouped together to be sold
- a table topped with lots of ripe bananas
- a table topped with lots of ripe bananas sitting next to each other
- several bunches of bananas on a table
- a large bunch of bananas sitting by some chairs

Target caption: a blue and white bus with two bicycles and people by it

TRAK top scoring train images (captions shown):
- a blue and white bus parked behind another vehicle
- a blue and white bus driving through a park next to trees
- a blue bus is traveling down the road
- a blue bus driving down a road next to people
- a blue and white bus is parked by a curb
- a young woman in a blue dress standing in front of parked bicycles

Target caption: a man riding a motorcycle with a helmet on

TRAK top scoring train images (captions shown):
- a person wearing a helmet is riding a motorcycle
- a person in a helmet is riding a motorcycle
- a person with a helmet is sitting on a motorcycle
- two people riding a motorcycle down a street
- a man is riding a very small motorcycle
- a man riding on the back of a motorcycle down a road

Target caption: a person riding a snowboard down a hill

TRAK top scoring train images (captions shown):
- a person on a snowboard rides on the hill
- a man riding a snowboard down a snow covered hill
- a man riding a snowboard with a backpack down a hill
- a man riding a snowboard down a hill
- a man riding a snowboard down a hill to a ramp
- a person riding a snowboard down a snow covered slope

QNLI is a natural language inference dataset from the GLUE benchmark. It is a binary classification task: given a question and a sentence, the goal is to predict whether the sentence contains the answer to the question. We finetune BERT-base models on QNLI.

Below we show the highest and lowest scoring question-answer pairs from the train set for a few randomly selected samples from the QNLI test set. Question-answer pairs with a high TRAK score have a high influence on BERT-base predicting entailment for the target.
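
Setting this up mirrors the quickstart; a minimal sketch, assuming the 'text_classification' task from the TRAK docs, with bert_model and qnli_train_loader as hypothetical placeholders (the ready-to-run version is in the BERT tutorial in our docs):

# Minimal sketch, not the exact tutorial code: bert_model and
# qnli_train_loader are hypothetical placeholders.
traker = TRAKer(model=bert_model, task='text_classification', train_set_size=...)

traker.load_checkpoint(bert_model.state_dict(), model_id=0)
for batch in qnli_train_loader:
    # assumed batch layout: (input_ids, token_type_ids, attention_mask, labels)
    traker.featurize(batch=batch, num_samples=batch[0].shape[0])
traker.finalize_features()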

Target

Q: How many households has kids under the age of 18 living in them?

A: There were 158,349 households, of which 68,511 (43.3%) had children under the age of 18 living in them, 69,284 (43.8%) were opposite-sex married couples living together, 30,547 (19.3%) had a female householder with no husband present, 11,698 (7.4%) had a male householder with no wife present.

Model prediction:

entailment

TRAK top scoring train samples:

Q: What percent of household have children under 18?

A: There were 46,917 households, out of which 7,835 (16.7%) had children under the age of 18 living in them, 13,092 (27.9%) were opposite-sex married couples living together, 3,510 (7.5%) had a female householder with no husband present, 1,327 (2.8%) had a male householder with no wife present.

Model prediction:

entailment

Q: What percent in the 2000 census had persons under the age of 18?

A: There are 44,497 households, out of which 15.8% have children under the age of 18, 27.5% are married couples living together, 7.5% have a female householder with no husband present, and 62.3% are non-families.

Model prediction:

entailment

TRAK bottom scoring train samples:

Q: Roughly how many same-sex couples were there?

A: There were 46,917 households, out of which 7,835 (16.7%) had children under the age of 18 living in them, 13,092 (27.9%) were opposite-sex married couples living together, 3,510 (7.5%) had a female householder with no husband present, 1,327 (2.8%) had a male householder with no wife present.

Model prediction:

no entailment

Q: What percentage of households in Atlantic City were made up of individuals?

A: There were 15,504 households, of which 27.3% had children under the age of 18 living with them, 25.9% were married couples living together, 22.2% had a female householder with no husband present, and 44.8% were non-families.

Model prediction:

no entailment

Target

Q: What is the hottest temperature record for Fresno?

A: The official record high temperature for Fresno is 115 °F (46.1 °C), set on July 8, 1905, while the official record low is 17 °F (−8 °C), set on January 6, 1913.

Model prediction:

entailment

TRAK top scoring train samples:

Q: What is the hottest temperature in Raleigh?

A: Extremes in temperature have ranged from -9 °F (-23 °C) on January 21, 1985 up to 105 °F (41 °C), most recently on July 8, 2012.

Model prediction:

entailment

Q: What day did Charleston's airport hit the coldest day on record?

A: The highest temperature recorded within city limits was 104 °F (40 °C), on June 2, 1985, and June 24, 1944, and the lowest was 7 °F (−14 °C) on February 14, 1899, although at the airport, where official records are kept, the historical range is 105 °F (41 °C) on August 1, 1999 down to 6 °F (−14 °C) on January 21, 1985.

Model prediction:

entailment

TRAK bottom scoring train samples:

Q: What was Tucson's record low?

A: At the University of Arizona, where records have been kept since 1894, the record maximum temperature was 115°.

Model prediction:

no entailment

Q: When does the temperature of morning type young adults reach its lowest?

A: Though variation is great among normal chronotypes, the average human adult's temperature reaches its minimum at about 05:00 (5 a.m.), about two hours before habitual wake time.

Model prediction:

no entailment

Target

Q: What genre of music is Lindisfarne classified as?

A: Lindisfarne are a folk-rock group with a strong Tyneside connection.

Model prediction:

entailment

TRAK top scoring train samples:

Q: What genre of music is featured at Junk?

A: The nightclub, Junk, has been nominated for the UK's best small nightclub, and plays host to a range of dance music's top acts.

Model prediction:

entailment

Q: Which political philosophy does Greece follow?

A: Greece is a democratic and developed country with an advanced high-income economy, a high quality of life and a very high standard of living.

Model prediction:

entailment

TRAK bottom scoring train samples:

Q: Which genre did Madonna started out in?

A: Stephen Thomas Erlewine noted that with her self-titled debut album, Madonna began her career as a disco diva, in an era that did not have any such divas to speak of.

Model prediction:

no entailment

Q: What type of sports centers are Wutaishan Sports Center and Nanjing Olympic Sports Center considered to be?

A: There are two major sports centers in Nanjing, Wutaishan Sports Center and Nanjing Olympic Sports Center.

Model prediction:

no entailment

Target

Q: What can rubisco do by mistake?

A: It can waste up to half the carbon fixed by the Calvin cycle.

Model prediction:

no entailment

TRAK top scoring train samples:

Q: What can clothing provide during hazardous activities?

A: Further, they can provide a hygienic barrier, keeping infectious and toxic materials away from the body.

Model prediction:

no entailment

Q: What usually happens with misdemeanors?

A: These may result in fines and sometimes the loss of one's driver's license, but no jail time.

Model prediction:

no entailment

TRAK bottom scoring train samples:

Q: Quantum Dot LEDs can do what special skill?

A: This allows quantum dot LEDs to create almost any color on the CIE diagram.

Model prediction:

entailment

Q: What does :76 Shadow Copy do?

A: It can only access previous versions of shared files stored on a Windows Server computer.:74 The subsystem on which these components worked, however, is still available for other software to use.:74

Model prediction:

entailment