# Probing

Probing tests large language models (LLMs) for linguistic abilities in order to understand how they encode and process various aspects of language.

There are two main probing paradigms:

1. Diagnostic probing: This examines internal neural representations of the model.

2. Prompting-based probing: This evaluates model outputs through behavioral tests.
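
To make the first paradigm concrete, here is a minimal sketch of a diagnostic probe: a small linear classifier trained on hidden-state vectors. The vectors below are synthetic stand-ins generated with NumPy (an assumption made so the example runs without a model); in a real probe they would be extracted from a layer of the model, and the names `train_linear_probe`, `hidden`, and `direction` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for hidden states: 200 vectors of dimension 16.
# Assumption: the property of interest (e.g. 0 = singular, 1 = plural)
# is encoded along one direction of the representation space.
n, d = 200, 16
labels = rng.integers(0, 2, size=n)
direction = rng.normal(size=d)
hidden = rng.normal(size=(n, d)) + np.outer(2 * labels - 1, direction)

# Hold out the last 50 examples to evaluate the probe
train, test = slice(0, 150), slice(150, n)


def train_linear_probe(x, y, lr=0.1, steps=500):
    """Fit a logistic-regression probe with plain gradient descent."""
    w = np.zeros(x.shape[1])
    b = 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(x @ w + b)))  # predicted P(label = 1)
        grad = p - y                            # gradient of the log loss
        w -= lr * (x.T @ grad) / len(y)
        b -= lr * grad.mean()
    return w, b


w, b = train_linear_probe(hidden[train], labels[train])
pred = (hidden[test] @ w + b > 0).astype(int)
probe_accuracy = (pred == labels[test]).mean()
print(f"Probe accuracy on held-out vectors: {probe_accuracy:.2f}")
```

High held-out accuracy suggests the property is linearly recoverable from the representations. The rest of this section uses the second, prompting-based paradigm instead.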

## Syntactic probing

We perform syntactic probing by testing the model's ability to handle subject-verb agreement.

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [None]:
import numpy as np

# Load pre-trained model and tokenizer
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


For every combination of subject pair and verb pair, we build one sentence with the singular subject and one with the plural subject, mask the verb position, and read off the model's probability for each verb form at the mask.

In [None]:
def syntactic_probe(sentence_template, subject_pairs, verb_pairs):
    results = {}

    for subj_sing, subj_plur in subject_pairs:
        for verb_sing, verb_plur in verb_pairs:
            sent_sing = sentence_template.format(subject=subj_sing, verb=tokenizer.mask_token)
            prob_sing = probe_sentence(sent_sing, [verb_sing, verb_plur])

            sent_plur = sentence_template.format(subject=subj_plur, verb=tokenizer.mask_token)
            prob_plur = probe_sentence(sent_plur, [verb_sing, verb_plur])

            results[(subj_sing, subj_plur, verb_sing, verb_plur)] = {
                'singular': prob_sing,
                'plural': prob_plur
            }

    return results

We take a sentence with a masked verb and calculate the probability of each verb option filling that mask.

In [None]:
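def probe_sentence(sentence, verb_options):
    """Return the probability of each candidate verb at the masked position."""
    inputs = tokenizer(sentence, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Position of the (single) mask token in the input sequence
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    logits = outputs.logits[0, mask_token_index, :]
    # Softmax over the whole vocabulary, so individual probabilities are small
    probs = torch.nn.functional.softmax(logits, dim=-1)

    verb_probs = {}
    for verb in verb_options:
        verb_id = tokenizer.convert_tokens_to_ids(verb)
        # Words outside the vocabulary map to the unknown token; skip them
        # rather than report the unknown-token probability under the verb's name
        if verb_id == tokenizer.unk_token_id:
            continue
        verb_probs[verb] = probs[0, verb_id].item()

    return verb_probs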
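
Because the softmax runs over the entire vocabulary, the raw probabilities of two specific verb forms are tiny and hard to compare across sentences. One option is to renormalize over just the two candidates. The helper below is a hypothetical addition (not part of the notebook's pipeline), and the numbers in the usage lines are illustrative values, not model output.

```python
def pairwise_preference(verb_probs, correct, incorrect):
    """Share of probability on `correct` after renormalizing over both forms."""
    p_correct = verb_probs[correct]
    p_incorrect = verb_probs[incorrect]
    total = p_correct + p_incorrect
    if total == 0:
        return 0.5  # no probability mass on either form: no preference
    return p_correct / total


# Illustrative values in the style of probe_sentence's return value
example = {"runs": 0.01, "run": 0.0001}
preference = pairwise_preference(example, "runs", "run")
print(f"Preference for the agreeing form: {preference:.3f}")
```

A value near 1 means the model strongly prefers the agreeing form, and a value near 0.5 means it barely distinguishes the two.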

We provide a sentence template, pairs of singular/plural subjects, and pairs of singular/plural verbs.

In [None]:
sentence_template = "The {subject} {verb} in the park."
subject_pairs = [
    ("dog", "dogs"),
    ("cat", "cats"),
    ("child", "children")
]
verb_pairs = [
    ("runs", "run"),
    ("plays", "play"),
    ("walks", "walk")
]

results = syntactic_probe(sentence_template, subject_pairs, verb_pairs)

def fmt(p):
    # Verbs missing from the probe output are shown as N/A
    return f"{p:.4f}" if p is not None else "N/A"

for (subj_sing, subj_plur, verb_sing, verb_plur), probs in results.items():
    print(f"Subject: {subj_sing}/{subj_plur}, Verb: {verb_sing}/{verb_plur}")
    print("Singular agreement:")
    print(f"  {verb_sing}: {fmt(probs['singular'].get(verb_sing))}")
    print(f"  {verb_plur}: {fmt(probs['singular'].get(verb_plur))}")
    print("Plural agreement:")
    print(f"  {verb_sing}: {fmt(probs['plural'].get(verb_sing))}")
    print(f"  {verb_plur}: {fmt(probs['plural'].get(verb_plur))}")
    print()

correct_predictions = 0
total_predictions = 0

for (subj_sing, subj_plur, verb_sing, verb_plur), probs in results.items():
    if verb_sing in probs['singular'] and verb_plur in probs['singular']:
        if probs['singular'][verb_sing] > probs['singular'][verb_plur]:
            correct_predictions += 1
        total_predictions += 1

    if verb_sing in probs['plural'] and verb_plur in probs['plural']:
        if probs['plural'][verb_plur] > probs['plural'][verb_sing]:
            correct_predictions += 1
        total_predictions += 1

accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
print(f"Overall accuracy: {accuracy:.2f}")



Subject: dog/dogs, Verb: runs/run
Singular agreement:
  runs: 0.0102
  run: 0.0001
Plural agreement:
  runs: 0.0001
  run: 0.0081

Subject: dog/dogs, Verb: plays/play
Singular agreement:
  plays: 0.0034
  play: 0.0001
Plural agreement:
  plays: 0.0001
  play: 0.0279

Subject: dog/dogs, Verb: walks/walk
Singular agreement:
  walks: 0.0165
  walk: 0.0003
Plural agreement:
  walks: 0.0001
  walk: 0.0076

Subject: cat/cats, Verb: runs/run
Singular agreement:
  runs: 0.0086
  run: 0.0002
Plural agreement:
  runs: 0.0001
  run: 0.0057

Subject: cat/cats, Verb: plays/play
Singular agreement:
  plays: 0.0083
  play: 0.0004
Plural agreement:
  plays: 0.0002
  play: 0.0839

Subject: cat/cats, Verb: walks/walk
Singular agreement:
  walks: 0.0260
  walk: 0.0008
Plural agreement:
  walks: 0.0002
  walk: 0.0102

Subject: child/children, Verb: runs/run
Singular agreement:
  runs: 0.0083
  run: 0.0002
Plural agreement:
  runs: 0.0000
  run: 0.0024

Subject: child/children, Verb: plays/play
Singular ag

## Pragmatics probing

Pragmatics probing tests whether an LLM can infer meanings beyond literal interpretations by considering context, social norms, and cultural factors.

In [6]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_name = "google/flan-t5-large"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

def probe_pragmatics(context, question, options):
    input_text = f"Context: {context}\n\nQuestion: {question}\nOptions: {', '.join(options)}\n\nAnswer:"
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

    # Beam search must keep at least as many beams as the sequences it returns
    num_beams = max(3, len(options))
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=20, num_return_sequences=len(options), num_beams=num_beams)

    decoded_outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

    # Heuristic scoring: the i-th option scores 1 if it appears in the i-th
    # returned beam, else 0. These are match indicators, not model probabilities.
    option_probs = {}
    for option, output in zip(options, decoded_outputs):
        option_probs[option] = 1 if option.lower() in output.lower() else 0

    # Normalize the indicators to sum to 1; fall back to uniform if nothing matched
    total = sum(option_probs.values())
    if total > 0:
        option_probs = {k: v / total for k, v in option_probs.items()}
    else:
        option_probs = {k: 1 / len(options) for k in options}

    return option_probs

examples = [
    {
        "context": "Alice: Did you enjoy the movie?\nBob: Well, the popcorn was good.",
        "question": "What does Bob mean?",
        "options": [
            "He liked the movie.",
            "He did not enjoy the movie.",
            "He only enjoyed the popcorn."
        ],
    },
    {
        "context": "Teacher: Can you explain why you missed class yesterday?\nStudent: I had a family emergency.",
        "question": "What is the student implying?",
        "options": [
            "The student wants sympathy.",
            "The student is making an excuse.",
            "The student is avoiding answering."
        ],
    },
    {
        "context": "Friend 1: Do you want to go out tonight?\nFriend 2: I have a lot of work to do.",
        "question": "What does Friend 2 mean?",
        "options": [
            "They want to go out later.",
            "They are politely declining.",
            "They are asking for help with work."
        ],
    }
]

for example in examples:
    print(f"Context:\n{example['context']}")
    print(f"Question: {example['question']}")

    results = probe_pragmatics(
        example["context"],
        example["question"],
        example["options"]
    )

    print("Probabilities for each option:")
    for option, prob in sorted(results.items(), key=lambda x: x[1], reverse=True):
        print(f"  {option}: {prob:.4f}")

    print("\n" + "="*50 + "\n")


Context:
Alice: Did you enjoy the movie?
Bob: Well, the popcorn was good.
Question: What does Bob mean?
Probabilities for each option:
  He did not enjoy the movie.: 1.0000
  He liked the movie.: 0.0000
  He only enjoyed the popcorn.: 0.0000


Context:
Teacher: Can you explain why you missed class yesterday?
Student: I had a family emergency.
Question: What is the student implying?
Probabilities for each option:
  The student is avoiding answering.: 1.0000
  The student wants sympathy.: 0.0000
  The student is making an excuse.: 0.0000


Context:
Friend 1: Do you want to go out tonight?
Friend 2: I have a lot of work to do.
Question: What does Friend 2 mean?
Probabilities for each option:
  They want to go out later.: 0.3333
  They are politely declining.: 0.3333
  They are asking for help with work.: 0.3333
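
The scores above are 0/1 string matches against generated beams, which is why they come out as 1.0000 or 0.0000, and why the last example falls back to a uniform 0.3333 when no option matches. A graded alternative is to score each option by its log-likelihood under the model and apply a softmax. The sketch below covers only that final normalization step. It assumes the per-option log-likelihoods have already been computed (for instance by passing each option as the target sequence to the model), and the function name `options_to_probs` and the numbers in the usage lines are hypothetical.

```python
import math


def options_to_probs(log_likelihoods, lengths=None):
    """Softmax over per-option log-likelihoods.

    If `lengths` (tokens per option) is given, scores are divided by length
    first so that longer options are not penalized for having more tokens.
    """
    scores = dict(log_likelihoods)
    if lengths is not None:
        scores = {opt: ll / lengths[opt] for opt, ll in scores.items()}

    # Subtract the maximum before exponentiating for numerical stability
    top = max(scores.values())
    exp_scores = {opt: math.exp(s - top) for opt, s in scores.items()}
    total = sum(exp_scores.values())
    return {opt: e / total for opt, e in exp_scores.items()}


# Hypothetical log-likelihoods, not produced by any model
hypothetical = {"option A": -4.0, "option B": -2.0, "option C": -6.0}
graded = options_to_probs(hypothetical)
for option, prob in sorted(graded.items(), key=lambda x: x[1], reverse=True):
    print(f"  {option}: {prob:.4f}")
```

Unlike the match indicators, these values vary smoothly with how strongly the model prefers one option over another.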


