Learning Equilibria by Gradient Descent

Games
Exposition
Published

August 21, 2025

Albert Gleizes. *La Chasse*. (1911)

Introduction

Given a set of agents playing a game, how do we determine their optimal strategic behavior?

In the last post we looked at ways to differentiably identify equivalence classes of games. In this short post, we’ll use gradient descent to find Nash equilibria of a simple game.

Background

Identifying Nash equilibria (or other strategic behavior) is difficult in general.¹

At some point I will cover traditional algorithms for finding equilibria, such as support enumeration or Lemke-Howson. For this post, however, I will compose parametrized agents with games and learn Nash equilibria via gradient descent.

Implementation

We’ll build some simple classes to implement this experiment.

Game

First, we need some game representation.

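# Imports used throughout the post
import functools
from abc import ABC, abstractmethod
from typing import List

import torch
import torch.nn as nn
import torch.optim as optim

class Game:
    def __init__(self, payoffs: torch.Tensor | nn.Parameter, actions: List[List[str]], name=None):
        # payoffs[i, a_1, ..., a_n] is player i's payoff for the joint action (a_1, ..., a_n)
        self.num_players = payoffs.shape[0]
        self.payoffs = payoffs
        self.actions = actions
        self.name = name or "Unnamed Game"
        self._size = payoffs.shape

    @property
    def size(self):
        return self._size

    def payoff(self, action_indices):
        # All players' payoffs for one joint action (keeps the leading player axis)
        return self.payoffs[(slice(None),) + tuple(action_indices)]

    def to(self, device):
        if isinstance(self.payoffs, nn.Parameter):
            # Re-wrap so the moved payoffs remain a learnable leaf parameter
            return Game(nn.Parameter(self.payoffs.detach().to(device)), self.actions, self.name)
        return Game(self.payoffs.to(device), self.actions, self.name)

    def clone(self):
        if isinstance(self.payoffs, nn.Parameter):
            return Game(nn.Parameter(self.payoffs.detach().clone()), self.actions, self.name)
        else:
            return Game(self.payoffs.detach().clone(), self.actions, self.name)

    def __repr__(self):
        learnable = isinstance(self.payoffs, nn.Parameter) and self.payoffs.requires_grad
        return f'<Game "{self.name}" size={self._size} learnable={learnable}>'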

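Here, the payoffs are either a plain tensor or learnable (if you pass in an nn.Parameter). Entry payoffs[i, a_1, ..., a_n] is player i’s payoff when the players choose actions a_1, ..., a_n. Learnable payoffs are useful for tasks like welfare optimization, mechanism design, and inverse RL.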
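To make the indexing convention concrete, here is a minimal pure-Python sketch (no torch, so it runs standalone) of what Game.payoff does. Nested lists stand in for the payoff tensor, and payoff_vector is a hypothetical helper that walks one axis per player. The numbers are the Stag Hunt payoffs used later in the post.

```python
from typing import List, Sequence

# Nested lists standing in for the payoff tensor: PAYOFFS[player][a1][a2]
# (Stag Hunt numbers from later in the post; 0 = Stag, 1 = Hare).
PAYOFFS = [
    [[8, 0], [2, 3]],  # player 1
    [[3, 1], [0, 2]],  # player 2
]

def payoff_vector(payoffs, action_indices: Sequence[int]) -> List[float]:
    """Return every player's payoff for one joint action, like Game.payoff."""
    result = []
    for player_table in payoffs:
        entry = player_table
        for idx in action_indices:  # descend one axis per player's action
            entry = entry[idx]
        result.append(entry)
    return result

print(payoff_vector(PAYOFFS, (0, 1)))  # P1 Stag, P2 Hare -> [0, 1]
```

The class does the same walk in one step: indexing with (slice(None),) + tuple(action_indices) keeps the leading player axis and selects the joint action on the rest.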

Agent

Next, we need agents that can play the game.

class Agent:
    def __init__(self, policy, name: str):
        self.name = name
        self.policy = policy

    def act(self, actions: List[str]):
        probs = self.policy(actions)
        dist = torch.distributions.Categorical(probs)
        action_idx = dist.sample().item()
        return action_idx, actions[action_idx]

An agent samples an action from its policy’s distribution and returns both the action’s index and its name.
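Sampling works the same way for any probability vector. As a stdlib-only sketch, random.choices can stand in for torch.distributions.Categorical; act, always_stag and uniform are illustrative stand-ins (mirroring Alice’s and Bob’s policies in the example further down), not part of the post’s code.

```python
import random
from typing import Callable, List, Tuple

# Hypothetical alias: a policy maps the action names to a probability vector
PolicyFn = Callable[[List[str]], List[float]]

def act(policy: PolicyFn, actions: List[str], rng: random.Random) -> Tuple[int, str]:
    """Sample one action index in proportion to the policy's probabilities."""
    probs = policy(actions)
    idx = rng.choices(range(len(actions)), weights=probs, k=1)[0]
    return idx, actions[idx]

def always_stag(actions: List[str]) -> List[float]:
    return [1.0 if a == "Stag" else 0.0 for a in actions]

def uniform(actions: List[str]) -> List[float]:
    return [1.0 / len(actions)] * len(actions)

rng = random.Random(0)
actions = ["Stag", "Hare"]
print(act(always_stag, actions, rng))  # (0, 'Stag')

counts = [0, 0]
for _ in range(10_000):
    counts[act(uniform, actions, rng)[0]] += 1
print(counts)  # roughly 5,000 each
```

Because the sampled index is a plain integer (the post’s Agent calls .item()), no gradient flows from a sampled action back to the policy. That is why the training loop later in the post optimizes expected payoffs rather than sampled ones.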

Policies

What kind of policies might the agent have?

Abstract

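We’ll start with the abstract base class. A _run_once helper wraps each subclass’s initialize method so that it runs at most once per policy instance, which prevents a policy from being initialized twice.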

def _run_once(method):
    attr_flag = f"__{method.__name__}_has_run"

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if getattr(self, attr_flag, False):
            return                         
        setattr(self, attr_flag, True)
        return method(self, *args, **kwargs)

    return wrapper

class Policy(ABC):
    def __init__(self):
        pass

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        # If the subclass overrides 'initialize', wrap it exactly once.
        if "initialize" in cls.__dict__:
            cls.initialize = _run_once(cls.__dict__["initialize"])

    @abstractmethod
    def initialize(self, actions: List[str]) -> None:
        pass
    
    @abstractmethod
    def forward(self, actions: List[str]) -> torch.Tensor:
        pass
    
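    def __call__(self, actions: List[str]) -> torch.Tensor:
        # Lazily initialize on first use; _run_once turns repeat calls into no-ops
        if hasattr(self, "initialize"):
            self.initialize(actions)
        # In subclasses that also inherit nn.Module, super() resolves to
        # nn.Module.__call__, which dispatches to forward()
        return super().__call__(actions)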
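The effect of the _run_once wrapper can be checked in isolation. The sketch below repeats the helper from above and applies it directly as a decorator to a hypothetical CountingPolicy class (in the post, Policy.__init_subclass__ applies the same wrapper automatically). The flag is stored on the instance, so each policy initializes exactly once, and later calls, even with a different action list, are ignored.

```python
import functools

def _run_once(method):
    # Same helper as above: remember, per instance, that the method already ran
    attr_flag = f"__{method.__name__}_has_run"

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if getattr(self, attr_flag, False):
            return None
        setattr(self, attr_flag, True)
        return method(self, *args, **kwargs)

    return wrapper

class CountingPolicy:  # hypothetical toy class for illustration
    def __init__(self):
        self.init_calls = 0

    @_run_once
    def initialize(self, actions):
        self.init_calls += 1
        self.num_actions = len(actions)

p, q = CountingPolicy(), CountingPolicy()
p.initialize(["Stag", "Hare"])
p.initialize(["Stag", "Hare", "Rabbit"])  # ignored: p is already initialized
q.initialize(["Stag"])                    # separate instance, separate flag
print(p.init_calls, p.num_actions, q.init_calls)  # 1 2 1
```

One consequence worth keeping in mind: a single policy object reused in a game with a different number of actions keeps the parameters it created for the first game.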

Logits

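Our first policy is a logits policy: one learnable logit per action, turned into probabilities with a softmax. The logits are created in initialize, once the action set is known.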

class LogitsPolicy(Policy, nn.Module):
    def __init__(self, initialization='uniform'):
        Policy.__init__(self)
        nn.Module.__init__(self)
        self.initialization = initialization
        self.logits = None  
    
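    def initialize(self, actions: List[str]) -> None:
        # Create one logit per action once the action set is known
        num_actions = len(actions)
        if self.initialization == 'uniform':
            self.logits = nn.Parameter(torch.zeros(num_actions))
        elif self.initialization == 'random':
            # Logits drawn uniformly from [0, 1)
            self.logits = nn.Parameter(torch.rand(num_actions))
        else:
            raise ValueError(f"Unknown initialization: {self.initialization!r}")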

    def forward(self, actions: List[str]) -> torch.Tensor:
        return torch.softmax(self.logits, dim=0)
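A stdlib-only sketch of the softmax step, using the standard trick of subtracting the largest logit first (which avoids overflow and does not change the result). It shows that the 'uniform' initialization (all-zero logits) gives equal probabilities, and that Agent 1’s final logits reported at the end of the post map back to its reported probabilities.

```python
import math
from typing import List

def softmax(logits: List[float]) -> List[float]:
    """Softmax; subtracting the max logit avoids overflow without changing the result."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

print(softmax([0.0, 0.0]))               # 'uniform' initialization -> [0.5, 0.5]
print(softmax([4.3300014, -2.8580525]))  # Agent 1's final logits from the run below
print(softmax([1000.0, 0.0]))            # no overflow thanks to the max shift -> [1.0, 0.0]
```

With initialization='random', each logit is drawn uniformly from [0, 1), so for two actions the starting probabilities lie roughly between 0.27 and 0.73.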

We’ll hold off on other policies until later posts.

Arena

Let’s compose Agents and Games in “Arena” objects. This cleanly separates Agents from Games; the Arena also initializes each agent’s policy with that player’s action set.

class Arena:

    def __init__(self, game: Game, agents: List[Agent]):
        assert len(agents) == game.num_players
        self.game = game
        self.agents = agents

        for i, agent in enumerate(self.agents):
            if hasattr(agent.policy, 'initialize'):
                agent.policy.initialize(self.game.actions[i])

    def play(self):
        action_indices = []
        actions_chosen = []

        for player_idx, agent in enumerate(self.agents):
            actions = self.game.actions[player_idx]
            action_idx, action = agent.act(actions)
            action_indices.append(action_idx)
            actions_chosen.append(action)

        payoffs = self.game.payoff(action_indices)
        return actions_chosen, payoffs

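    def expected_payoffs(self):
        # Each agent's mixed strategy over its own actions
        dists = [agent.policy(self.game.actions[i]) for i, agent in enumerate(self.agents)]
        # Joint distribution as a flattened outer product (row-major, matching the payoff layout)
        joint_dist = dists[0]
        for dist in dists[1:]:
            joint_dist = torch.einsum('i,j->ij', joint_dist.flatten(), dist).flatten()

        # E[u_i] = sum over joint actions of P(joint action) * u_i(joint action)
        payoffs_flat = self.game.payoffs.reshape(self.game.num_players, -1)
        exp_payoffs = (joint_dist * payoffs_flat).sum(-1)
        return exp_payoffs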

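The “play” function runs one round of the game: each agent samples an action, and every player receives its payoff for the resulting joint action. “expected_payoffs” instead computes each agent’s exact expected payoff under the product of the agents’ mixed strategies; unlike “play”, which samples discrete actions, it is differentiable with respect to the policy parameters.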
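expected_payoffs builds the joint distribution as a flattened outer product and multiplies it against the payoffs reshaped to (players, joint actions). The same sum, written out in stdlib Python with itertools.product so the n-player case is explicit, makes a handy cross-check. This expected_payoffs is a hypothetical standalone function taking nested lists, not the Arena method.

```python
import itertools
from typing import List, Sequence

def expected_payoffs(payoffs, dists: Sequence[Sequence[float]]) -> List[float]:
    """E[u_i] = sum over joint actions of (product of players' probabilities) * u_i.

    payoffs is nested as payoffs[player][a1][a2]...[an]; dists[k] is player k's mixed strategy.
    """
    num_players = len(payoffs)
    totals = [0.0] * num_players
    for joint in itertools.product(*(range(len(d)) for d in dists)):
        # Probability of this joint action under independent play
        prob = 1.0
        for k, a in enumerate(joint):
            prob *= dists[k][a]
        for i in range(num_players):
            entry = payoffs[i]
            for a in joint:
                entry = entry[a]
            totals[i] += prob * entry
    return totals

STAG_HUNT = [[[8, 0], [2, 3]], [[3, 1], [0, 2]]]  # 0 = Stag, 1 = Hare
alice = [1.0, 0.0]  # always Stag
bob = [0.5, 0.5]    # uniform
print(expected_payoffs(STAG_HUNT, [alice, bob]))  # [4.0, 2.0]
```

With Alice always playing Stag and Bob uniform, it reproduces the values printed by the torch version in the example below.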

Example

Let’s look at an example, starting with the Stag Hunt.

if __name__ == "__main__":
    staghunt_actions = [["Stag", "Hare"], ["Stag", "Hare"]]
    p1_payoffs = [
        [8, 0],  # P1 plays Stag vs P2's [Stag, Hare]
        [2, 3]   # P1 plays Hare vs P2's [Stag, Hare]
    ]

    p2_payoffs = [
        [3, 1],  # P1 plays Stag: P2's payoffs for P2's [Stag, Hare]
        [0, 2]   # P1 plays Hare: P2's payoffs for P2's [Stag, Hare]
    ]
    # Stacked shape is (player, P1 action, P2 action)
    payoffs = nn.Parameter(torch.tensor([p1_payoffs, p2_payoffs], dtype=torch.float))
    stag_hunt = Game(payoffs, staghunt_actions, "Stag Hunt")
    print(stag_hunt)
    
    alice = Agent(
        policy=lambda actions: torch.tensor([1.0 if a=="Stag" else 0.0 for a in actions]), 
        name="Alice"
    )
    bob = Agent(
        policy=lambda actions: torch.ones(len(actions))/len(actions), 
        name="Bob"
    )
    
    stag_hunt_arena = Arena(stag_hunt, [alice, bob])
    
    print(stag_hunt_arena.expected_payoffs())
    print(stag_hunt_arena.play())
    print(stag_hunt_arena.play())

We initialize two agents with fixed (non-learnable) policies: Alice always plays Stag, while Bob plays Stag or Hare uniformly at random. We then compute their expected payoffs and play two rounds.

<Game "Stag Hunt" size=torch.Size([2, 2, 2]) learnable=True>
tensor([4., 2.], grad_fn=<SumBackward1>)
(['Stag', 'Hare'], tensor([0., 1.], grad_fn=<SelectBackward0>))
(['Stag', 'Stag'], tensor([8., 3.], grad_fn=<SelectBackward0>))

The expected payoff is 4 for Alice (½·8 + ½·0) and 2 for Bob (½·3 + ½·1). In the two rounds they play, Bob first plays “Hare” (payoffs 0 and 1, low for both players), then “Stag” (payoffs 8 and 3, the highest for both).

Now let’s build differentiable agents:

...
    diff_alice = Agent(
        policy=LogitsPolicy(initialization='random'),
        name="DiffAlice"
    )
    diff_bob = Agent(
        policy=LogitsPolicy(initialization='random'),
        name="DiffBob"
    )
    diff_agents = [diff_alice, diff_bob]
    diff_stag_hunt_arena = Arena(stag_hunt, diff_agents)
    
    optimizers = [
        optim.Adam(diff_alice.policy.parameters(), lr=0.1),
        optim.Adam(diff_bob.policy.parameters(), lr=0.1)
    ]

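    # Training loop: each player maximizes its own expected payoff
    for step in range(200):
        exp_payoffs = diff_stag_hunt_arena.expected_payoffs()

        # Player 1 update (retain the graph so player 2 can reuse this forward pass)
        optimizers[0].zero_grad()
        (-exp_payoffs[0]).backward(retain_graph=True)
        optimizers[0].step()

        # Player 2 update: its gradient comes from the same pre-update forward pass,
        # so the two updates are simultaneous. (The learnable payoffs also
        # accumulate .grad here, but no optimizer updates them.)
        optimizers[1].zero_grad()
        (-exp_payoffs[1]).backward()
        optimizers[1].step()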

        if step % 20 == 0:
            print(f"Step {step}, Expected Payoffs: {exp_payoffs.detach().cpu().numpy()}")

    # Final Policies
    for i, agent in enumerate(diff_agents):
        logits = agent.policy.logits.detach().numpy()
        probs = agent.policy(stag_hunt.actions[i]).detach().numpy()
        print(f"Agent {i+1} final logits: {logits}")
        print(f"Agent {i+1} final probabilities: {probs}")
        print(f"Agent {i+1} prefers: {'Stag' if probs[0] > probs[1] else 'Hare'}")
        print()

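In the main loop, each player minimizes the negative of its own expected payoff with its own Adam optimizer. Both gradients come from the same forward pass, so the updates are simultaneous; alternating updates, where the second player responds to the first player’s already-updated policy, are another option. We see: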
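It helps to see what gradient each player is following. For one player with logits θ and policy π = softmax(θ), let Q_a be the expected payoff of action a against the other players’ current strategies, so that E[u] = Σ_a π_a Q_a. The standard softmax derivative gives:

$$
\frac{\partial \pi_a}{\partial \theta_b} = \pi_a\,(\delta_{ab} - \pi_b)
\qquad\Longrightarrow\qquad
\frac{\partial\, \mathbb{E}[u]}{\partial \theta_b}
= \sum_a Q_a\, \pi_a\,(\delta_{ab} - \pi_b)
= \pi_b\,\bigl(Q_b - \mathbb{E}[u]\bigr).
$$

Logits of actions that beat the player’s current average are pushed up (the loop descends on −E[u], with Adam rescaling each step). The gradient is exactly zero only when every action earns the same Q, and it becomes tiny as the policy approaches a pure strategy, so a softmax policy approaches a pure equilibrium without reaching it exactly, consistent with the small but nonzero Hare probabilities in the output.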

Step 0, Expected Payoffs: [2.874151  1.4458582]
Step 20, Expected Payoffs: [7.565231 2.889052]
Step 40, Expected Payoffs: [7.9332013 2.9811559]
Step 60, Expected Payoffs: [7.9630227 2.9893801]
Step 80, Expected Payoffs: [7.972061 2.991976]
Step 100, Expected Payoffs: [7.9772897 2.9934921]
Step 120, Expected Payoffs: [7.9810405 2.994578 ]
Step 140, Expected Payoffs: [7.9839015 2.995403 ]
Step 160, Expected Payoffs: [7.986139  2.9960463]
Step 180, Expected Payoffs: [7.9879236 2.9965587]
Agent 1 final logits: [ 4.3300014 -2.8580525]
Agent 1 final probabilities: [9.9924505e-01 7.5498747e-04]
Agent 1 prefers: Stag

Agent 2 final logits: [ 3.8065035 -3.3711162]
Agent 2 final probabilities: [9.9923706e-01 7.6290034e-04]
Agent 2 prefers: Stag

which is indeed a Nash equilibrium: (Stag, Stag), the payoff-dominant one.²
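We can double-check this by brute force. A minimal stdlib sketch, assuming the payoff layout from the example (payoffs[player][P1 action][P2 action], 0 = Stag, 1 = Hare): a profile is a pure Nash equilibrium if neither player gains by deviating alone, and the fully mixed equilibrium follows from the indifference conditions. is_pure_nash is an illustrative helper.

```python
from fractions import Fraction
from itertools import product

# Stag Hunt payoffs, indexed U[a1][a2] (0 = Stag, 1 = Hare)
U1 = [[8, 0], [2, 3]]  # player 1
U2 = [[3, 1], [0, 2]]  # player 2

def is_pure_nash(a1: int, a2: int) -> bool:
    """No player can gain by unilaterally switching actions."""
    p1_ok = all(U1[a1][a2] >= U1[d][a2] for d in range(2))
    p2_ok = all(U2[a1][a2] >= U2[a1][d] for d in range(2))
    return p1_ok and p2_ok

pure = [(a1, a2) for a1, a2 in product(range(2), repeat=2) if is_pure_nash(a1, a2)]
print(pure)  # [(0, 0), (1, 1)]: (Stag, Stag) and (Hare, Hare)

# Fully mixed equilibrium from the indifference conditions:
# q = P(P2 plays Stag) makes P1 indifferent: 8q + 0(1-q) = 2q + 3(1-q)
# p = P(P1 plays Stag) makes P2 indifferent: 3p + 0(1-p) = 1p + 2(1-p)
q = Fraction(U1[1][1] - U1[0][1], (U1[0][0] - U1[1][0]) + (U1[1][1] - U1[0][1]))
p = Fraction(U2[1][1] - U2[1][0], (U2[0][0] - U2[0][1]) + (U2[1][1] - U2[1][0]))
print(p, q)  # 1/2 1/3
```

So the Stag Hunt has two pure equilibria and one mixed one (P1 plays Stag with probability 1/2, P2 with probability 1/3). These probabilities also mark the tipping points for learning: if P2 puts more than 1/3 on Stag, Stag is P1’s better action, and if P1 puts more than 1/2 on Stag, Stag is P2’s better action.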

Conclusion

Our differentiable agents converged to mutual cooperation (both playing Stag), which is a Nash equilibrium of the Stag Hunt, and the payoff-dominant one. In principle, the approach extends to n-player games (expected_payoffs already builds the joint distribution for any number of players) and, with a different game and policy representation, to continuous action spaces. The framework is also compositional: we can swap in different policy architectures, loss functions, or optimization algorithms to explore different solution concepts or learning dynamics.³

Gradient descent doesn’t guarantee convergence to Nash equilibria in all games: in zero-sum games the dynamics may cycle, in games with multiple equilibria the outcome depends on initialization, and simultaneous updates can be unstable.
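These failure modes are easy to reproduce without torch. The sketch below swaps in plain simultaneous gradient ascent (fixed step size, no Adam) with the analytic softmax gradient π_b(Q_b − E[u]) in place of autograd; all function names are hypothetical. In the Stag Hunt, starting both players at P(Stag) = 0.9 or at P(Stag) = 0.1 leads to different equilibria. In Matching Pennies, a zero-sum game whose only equilibrium is the mixed profile (0.5, 0.5), the same update never settles there.

```python
import math
from typing import List, Sequence, Tuple

def softmax(z: List[float]) -> List[float]:
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def action_values(U1, U2, pi1, pi2) -> Tuple[List[float], List[float]]:
    """Q-values: each action's expected payoff against the opponent's current mix (U[a1][a2] layout)."""
    q1 = [sum(pi2[b] * U1[a][b] for b in range(len(pi2))) for a in range(len(pi1))]
    q2 = [sum(pi1[a] * U2[a][b] for a in range(len(pi1))) for b in range(len(pi2))]
    return q1, q2

def simultaneous_ascent(U1, U2, th1: Sequence[float], th2: Sequence[float], lr=0.1, steps=2000):
    """Simultaneous gradient ascent on both players' logits; returns the (pi1, pi2) trajectory."""
    th1, th2 = list(th1), list(th2)
    traj = []
    for _ in range(steps):
        pi1, pi2 = softmax(th1), softmax(th2)
        traj.append((pi1, pi2))
        q1, q2 = action_values(U1, U2, pi1, pi2)
        e1 = sum(p * v for p, v in zip(pi1, q1))
        e2 = sum(p * v for p, v in zip(pi2, q2))
        # Both gradients are evaluated at the current point before either player moves
        g1 = [p * (v - e1) for p, v in zip(pi1, q1)]
        g2 = [p * (v - e2) for p, v in zip(pi2, q2)]
        th1 = [t + lr * g for t, g in zip(th1, g1)]
        th2 = [t + lr * g for t, g in zip(th2, g2)]
    traj.append((softmax(th1), softmax(th2)))
    return traj

# Stag Hunt (0 = Stag, 1 = Hare): the equilibrium reached depends on the start
SH1, SH2 = [[8, 0], [2, 3]], [[3, 1], [0, 2]]
lean_stag = [math.log(9), 0.0]  # P(Stag) = 0.9
lean_hare = [0.0, math.log(9)]  # P(Stag) = 0.1
stag_end = simultaneous_ascent(SH1, SH2, lean_stag, lean_stag)[-1]
hare_end = simultaneous_ascent(SH1, SH2, lean_hare, lean_hare)[-1]
print("from 0.9:", stag_end[0][0], stag_end[1][0])  # both P(Stag) close to 1
print("from 0.1:", hare_end[0][0], hare_end[1][0])  # both P(Stag) close to 0

# Matching Pennies: P1 wants to match, P2 wants to mismatch
MP1 = [[1, -1], [-1, 1]]
MP2 = [[-1, 1], [1, -1]]
traj = simultaneous_ascent(MP1, MP2, [math.log(1.5), 0.0], [0.0, 0.0])  # start at (0.6, 0.5)
closest = min(max(abs(a[0] - 0.5), abs(b[0] - 0.5)) for a, b in traj)
print(f"closest approach to (0.5, 0.5): {closest:.3f}")
```

In this Matching Pennies run the strategies circle around the mixed equilibrium instead of converging to it: with explicit simultaneous steps the cycle does not shrink, so the closest approach stays well away from (0.5, 0.5).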

Footnotes

  1. Computing a Nash equilibrium is PPAD-complete, even in two-player games. See here or here.

  2. One of them. Rerunning with different random initializations, you can also see the agents converge to [Hare, Hare].

  3. In a future post we will hopefully take composition further, and compose game inputs/outputs.