Skip to content
18 changes: 18 additions & 0 deletions discojs/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type { EventConnection } from './event_connection.js'
import type { Aggregator } from '../aggregator/index.js'
import { EventEmitter } from '../utils/event_emitter.js'
import { type } from "./messages.js";
import { ModelWeightAccess } from "../training/disco.js";

const debug = createDebug("discojs:client");

Expand All @@ -24,6 +25,7 @@ const debug = createDebug("discojs:client");
export abstract class Client<N extends Network> extends EventEmitter<{
status: RoundStatus;
participants: number;
modelSynced: WeightsContainer | undefined;
}> {
// Own ID provided by the network's server.
protected _ownId?: NodeID
Expand All @@ -38,6 +40,9 @@ export abstract class Client<N extends Network> extends EventEmitter<{
*/
protected promiseForMoreParticipants: Promise<void> | undefined = undefined;

// Interface to access trainer's model weights
protected modelWeightAccess?: ModelWeightAccess;

/**
* When the server notifies the client that they can resume training
* after waiting for more participants, we want to be able to display what
Expand All @@ -56,6 +61,15 @@ export abstract class Client<N extends Network> extends EventEmitter<{
) {
super()
}

/**
* Used for decentralized learning.
* Set the interface used by client to access to trainer's model weights.
* Disco object provides this access.
*/
setModelWeightAccess(modelWeightAccess: ModelWeightAccess){
this.modelWeightAccess = modelWeightAccess
}

/**
* Communication callback called at the beginning of every training round.
Expand Down Expand Up @@ -193,6 +207,10 @@ export abstract class Client<N extends Network> extends EventEmitter<{
return await serialization.model.decode(encoded)
}

public finishRound(): void{
// DecentralizedClient override the method to clean up round state
}

/**
* Number of contributors to a collaborative session
* If decentralized, it should be the number of peers
Expand Down
54 changes: 54 additions & 0 deletions discojs/src/client/decentralized/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Peer Connection in Decentralized Learning
This document describes how peer connections for decentralized learning are established.

Peer connections for decentralized learning are coordinated by the server. However, model weight updates are shared only between peers.

## Peer Connection and Round Participation

Clients first connect to the server. The server then coordinates which clients should participate in each decentralized training round.

At the beginning of each round, clients send `JoinRound` messages to the server. After that, they send `PeerIsReady` messages to notify the server that they are ready to establish peer connections and exchange model updates.

For each round, the server assumes that all currently connected clients are participating, except for clients that are still syncing their model after joining in the middle of training. Once the server has received `PeerIsReady` messages from all expected participants for the round, it sends a `PeersForRound` message to each participant. This message tells clients which peers they should connect to for the current round.

After receiving `PeersForRound`, participants establish peer connections and proceed with decentralized learning.

### Connection Ready Check and Signaling Weight Sharing

Before clients start sharing model weights, the server checks that all participants have successfully completed peer connection setup. This prevents faster clients from starting weight sharing while slower clients are still establishing connections.
The process is as follows:
1. After completing peer connections, each client sends a `ConnectionsReady` message to the server.
2. The server counts the number of received `ConnectionsReady` messages.
3. Once the number of `ConnectionsReady` messages matches the number of expected participants for the round, the server sends a `StartWeightSharing` message to all round participants.
4. Clients only start sharing model updates after receiving `StartWeightSharing`.

### Connection Retries and Failed Client Disconnection

Peer connections may fail, so decentralized training includes a retry mechanism. The maximum number of retries is controlled by `maxConnectionRetry`, which is specified in the task training information.

The retry mechanism is triggered when the server times out while waiting for `ConnectionsReady` messages.
The process is as follows:
1. If the number of retries is still below `maxConnectionRetry`, the server sends a `RetryPeerConnection` message to all peers in the current round.
2. When clients receive `RetryPeerConnection`, they clean up their peer pool and aggregator nodes, then rerun the peer connection phase.
3. If the connection setup still fails after more than `maxConnectionRetry` attempts, the server removes the failed peers from the round peer list.
4. The server sends a `ConnectionFail` message to the failed peers.
5. When a client receives `ConnectionFail`, it disconnects from the server.
6. The remaining clients receive `RetryPeerConnection` and retry the peer connection phase without the failed peers.

This allows the remaining participants to continue training even when one or more peers fail to establish connections.

### Model Syncing for Participants Joining in the Middle of Training

Participants that join in the middle of training need to receive the latest model before they can participate in future rounds. In decentralized learning, peers do not send model weights to the server, so the newcomer must request the latest model from an existing peer.

The model syncing process is as follows:
1. When a new participant joins in the middle of training, the server marks the participant as having joined mid-training in `NewDecentralizedNodeInfo`.
2. After receiving `NewDecentralizedNodeInfo`, the new client sets a local flag indicating that it needs model syncing.
3. When training begins, if this flag is set, the new client sends a `ModelSyncRequest` message to the server.
4. After receiving `ModelSyncRequest`, the server sends messages as step 5 and 6, using selected model provider information from previous training round.
5. The server sends `SignalModelProvider` to the new participant with information about the provider peer.
6. The server sends `SignalNewPeer` to the provider peer with information about the newly joined peer.
7. Using this signaling information, the new participant and provider peer establish a peer connection.
8. The provider waits until the current aggregation round has finished, then sends the latest model to the new participant using a `SharedModel` message.
9. The new participant receives the model and updates its local model weights.
10. After syncing, the new participant can join subsequent decentralized training rounds.
179 changes: 170 additions & 9 deletions discojs/src/client/decentralized/decentralized_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import type { DataType, Model, WeightsContainer } from "../../index.js";
import { serialization } from "../../index.js";
import { Client, shortenId } from '../client.js'
import { type NodeID } from '../index.js'
import { type, type ClientConnected } from '../messages.js'
import { type, type ClientConnected, NarrowMessage } from '../messages.js'
import { timeout } from '../utils.js'
import { WebSocketServer, waitMessage, type PeerConnection, waitMessageWithTimeout } from '../event_connection.js'
import { PeerPool } from './peer_pool.js'
Expand All @@ -26,6 +26,14 @@ export class DecentralizedClient extends Client<"decentralized"> {
#pool?: PeerPool
#connections?: Map<NodeID, PeerConnection>

// Flag if this model requires model synchronization
#modelSyncNeeded?: boolean

// Check if the training round is in progress
// Used to get the latest model for model synchronization
#roundFinishedPromise?: Promise<void>
#resolveRoundFinished?: () => void // contains resolver

// Used to handle timeouts and promise resolving after calling disconnect
private get isDisconnected() : boolean {
return this._server === undefined
Expand All @@ -36,6 +44,24 @@ export class DecentralizedClient extends Client<"decentralized"> {
// Emits the `participants` event
this.nbOfParticipants = this.aggregator.nodes.size === 0 ? 1 : this.aggregator.nodes.size
}

// Used by model provider peer during model syncing
private async handleSignalNewPeer(event: NarrowMessage<type.SignalNewPeer>): Promise<void> {
if (this.#pool === undefined) throw new Error('received signal about new peer but peer pool is undefined')
Comment thread
ahzero7d1 marked this conversation as resolved.
Outdated
const roundFinishedPromise = this.#roundFinishedPromise
const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{})

const newcomerConn = syncConnection.get(event.newNode)

if (newcomerConn === undefined){
// if connection with newly joining client fails, print debug message
// and return
debug(`Cannot connect to newly joined client [${event.newNode}]`)
return
}

await this.sendModel(newcomerConn, roundFinishedPromise)
}

/**
* Public method called by disco.ts when starting training. This method sends
Expand Down Expand Up @@ -69,6 +95,13 @@ export class DecentralizedClient extends Client<"decentralized"> {
this.#pool.signal(event.peer, event.signal)
})

// Listen if the client is selected as a model provider node for a newly joining client.
// Upon receiving the signal, this client establishes a connection with the newcomer
// and sends the latest model weights.
this.server.on(type.SignalNewPeer, (event) => {
void this.handleSignalNewPeer(event)
})

// c.f. setupServerCallbacks doc for explanation
let receivedEnoughParticipants = false
this.setupServerCallbacks(() => receivedEnoughParticipants = true)
Expand All @@ -79,8 +112,9 @@ export class DecentralizedClient extends Client<"decentralized"> {
this.server.send(msg)

const { id, waitForMoreParticipants,
nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo)

nbOfParticipants, joinedMidTraining } = await waitMessage(this.server, type.NewDecentralizedNodeInfo)

this.#modelSyncNeeded = joinedMidTraining
this.nbOfParticipants = nbOfParticipants


Expand Down Expand Up @@ -129,14 +163,50 @@ export class DecentralizedClient extends Client<"decentralized"> {
* When connected, one peer creates a promise for every other peer's weight update
* and waits for it to resolve.
*
* If a client joined the training after the first round,
* model syncing happens first to get the latest model.
*/
override async onRoundBeginCommunication(): Promise<void> {
if (this.#modelSyncNeeded) {
// 1. If model sync is needed, send server a request
this.server.send({ type: type.ModelSyncRequest })

// 2. Get the provider information from the server
const providerInfo = await waitMessageWithTimeout(this.server, type.SignalModelProvider, 30_000, "Timeout while waiting for the latest model provider")

if (this.#pool === undefined) {
throw new Error('peer pool is undefined, make sure to call `client.connect()` first')
}

// 3. Connect with model provider client and get the latest model
const syncConnection = await this.#pool.getPeers(
Set([providerInfo.providerNode]),
this.server,
()=>{}
)
Comment on lines +184 to +188

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling getPeers adds the providerNode to the peer pool (this line) which is then never removed and unused in aggregation either.
I don't think that's a big deal as I assume this model sync won't happen that often but worth adding a comment to help debug if it becomes problematic

const providerConn = syncConnection.get(providerInfo.providerNode)

if (providerConn === undefined){
throw new Error("The latest model provider is not connected")
}

const latestModel = await this.receiveModel(providerConn)
this.modelWeightAccess?.setModelWeight(latestModel)

this.emit("modelSynced", this.modelWeightAccess?.getModelWeight())
this.#modelSyncNeeded = false
}

// Notify the server we want to join the next round so that the server
// waits for us to be ready before sending the list of peers for the round
this.server.send({ type: type.JoinRound })
// Store the promise for the current round's aggregation result.
// We will await for it to resolve at the end of the round when exchanging weight updates.
this.aggregationResult = this.aggregator.getPromiseForAggregation()

// Do not proceed to local training when minNbOfParticipants condition is not satisfied
await this.waitForParticipantsIfNeeded()

this.saveAndEmit("local training")
return Promise.resolve()
}
Expand All @@ -149,11 +219,55 @@ export class DecentralizedClient extends Client<"decentralized"> {
// Once enough new participants join we can display the previous status again
this.saveAndEmit("connecting to peers")
// First we check if we are waiting for more participants before sending our weight update
await this.waitForParticipantsIfNeeded()
// Create peer-to-peer connections with all peers for the round
await this.establishPeerConnections()

while(true){
// Wait until enough participants are available before continuing the round
// Checks minNbOfParticipants requirement for
// when participants disconnect when connection error happens continuously
await this.waitForParticipantsIfNeeded()

// Create peer-to-peer connections with all peers for the round
await this.establishPeerConnections()

// Wait for connection related messages from the server before exchanging weight updates
// (1) If the client receives a StartWeightSharing message, it proceeds to weight update exchange
// (2) If it receives a RetryPeerConnections message, it retries peer connection establishment
// (3) After multiple retires, if the connection is still unsuccessful, the server starts excluding nodes from the round
// and sends a ConnectionFail message to those nodes
// (4) Upon receiving ConnectionFail, the client disconnects from the server
const msg = await Promise.race([
waitMessage(this.server, type.StartWeightSharing),
waitMessage(this.server, type.RetryPeerConnections),
waitMessage(this.server, type.ConnectionFail),
])
Comment on lines +240 to +244

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This creates two event listeners that are not resolved per round and accumulates throughout the rounds. I don't think it's a huge deal but can you add a comment/TODO to note this?


if (msg.type === type.StartWeightSharing){
// Generate a promise that resolves when round training finishes
if (this.#roundFinishedPromise === undefined){
this.#roundFinishedPromise = new Promise<void>((resolve) => {
this.#resolveRoundFinished = resolve
})
}
break
} else if (msg.type === type.RetryPeerConnections){
debug(`[${shortenId(this.ownId)}] retrying peer connection establishment`)
// clear the communication round peer pool
await this.#pool?.shutdown()
this.#pool = new PeerPool(this.ownId)
// clear the connections
this.#connections = Map()
this.setAggregatorNodes(Set(this.ownId))
continue
} else if (msg.type === type.ConnectionFail){
debug(`[${shortenId(this.ownId)}] disconnect from the server`)
await this.disconnect()
throw new Error("Client disconnected after connection failure")
}
}
// Exchange weight updates with peers and return aggregated weights
return await this.exchangeWeightUpdates(weights)
const aggregatedWeight = await this.exchangeWeightUpdates(weights)
Comment on lines +254 to +270

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this scenario ever covered in the unit tests?


return aggregatedWeight
}

/**
Expand All @@ -178,8 +292,9 @@ export class DecentralizedClient extends Client<"decentralized"> {
try {
debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`);
const receivedMessage = await waitMessage(this.server, type.PeersForRound)

const peers = Set(receivedMessage.peers)
debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray());

if (this.ownId !== undefined && peers.has(this.ownId)) {
throw new Error('received peer list contains our own id')
Expand All @@ -198,7 +313,9 @@ export class DecentralizedClient extends Client<"decentralized"> {
(conn) => this.receivePayloads(conn)
)

debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS());
// Signal server that all connections with other peers in the round are established
this.server.send({ type: type.ConnectionsReady });
debug(`[${shortenId(this.ownId)}] peer connections ready: %o`, connections.keySeq().toJS());
this.#connections = connections
} catch (e) {
debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e);
Expand Down Expand Up @@ -303,4 +420,48 @@ export class DecentralizedClient extends Client<"decentralized"> {
}
return await this.aggregationResult
}

/**
* Receive model from the model provider.
*/
private async receiveModel(providerConn: PeerConnection): Promise<WeightsContainer>{
const message = await waitMessageWithTimeout(providerConn, type.SharedModel, 30_000, "Timeout while waiting for the latest model")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should make this timeout value a task parameter, depending on the model size it can make sense to increase the timeout


const decoded = serialization.weights.decode(message.model)
return decoded
}

/**
* Send the latest available model to a newly joining client.
* If the current training round is in progress, wait until the round finishes
* and receive the latest aggregated model.
*/
private async sendModel(newcomerConn: PeerConnection, roundFinishedPromise: Promise<void> | undefined): Promise<void> {
// wait until the round finishes to get the latest model
if (roundFinishedPromise !== undefined){
await roundFinishedPromise
}

const model = this.modelWeightAccess?.getModelWeight()

if (model === undefined){
debug("Failed to get the latest model from model provider client")
return
}
const encoded = await serialization.weights.encode(model)

const message: messages.SharedModel = {
type: type.SharedModel,
model: encoded
}
newcomerConn.send(message)
}

// Resolve the round finished promise and reset related state
override finishRound(): void{
// Mark round as finished so that model synchronization can proceed
this.#resolveRoundFinished?.()
this.#roundFinishedPromise = undefined
this.#resolveRoundFinished = undefined
}
}
Loading
Loading