-
Notifications
You must be signed in to change notification settings - Fork 32
Fix decentralized clients connection issue #1110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 9 commits
2900c65
dd22a7d
7d0e634
947d2c6
18e777e
b443f53
9c7c488
09ced04
14b2a4b
3bfce91
62196b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| # Peer Connection in Decentralized Learning | ||
| This document describes how peer connections for decentralized learning are established. | ||
|
|
||
| Peer connections for decentralized learning are coordinated by the server. However, model weight updates are shared only between peers. | ||
|
|
||
| ## Peer Connection and Round Participation | ||
|
|
||
| Clients first connect to the server. The server then coordinates which clients should participate in each decentralized training round. | ||
|
|
||
| At the beginning of each round, clients send `JoinRound` messages to the server. After that, they send `PeerIsReady` messages to notify the server that they are ready to establish peer connections and exchange model updates. | ||
|
|
||
| For each round, the server assumes that all currently connected clients are participating, except for clients that are still syncing their model after joining in the middle of training. Once the server has received `PeerIsReady` messages from all expected participants for the round, it sends a `PeersForRound` message to each participant. This message tells clients which peers they should connect to for the current round. | ||
|
|
||
| After receiving `PeersForRound`, participants establish peer connections and proceed with decentralized learning. | ||
|
|
||
| ### Connection Ready Check and Signaling Weight Sharing | ||
|
|
||
| Before clients start sharing model weights, the server checks that all participants have successfully completed peer connection setup. This prevents faster clients from starting weight sharing while slower clients are still establishing connections. | ||
| The process is as follows: | ||
| 1. After completing peer connections, each client sends a `ConnectionsReady` message to the server. | ||
| 2. The server counts the number of received `ConnectionsReady` messages. | ||
| 3. Once the number of `ConnectionsReady` messages matches the number of expected participants for the round, the server sends a `StartWeightSharing` message to all round participants. | ||
| 4. Clients only start sharing model updates after receiving `StartWeightSharing`. | ||
|
|
||
| ### Connection Retries and Failed Client Disconnection | ||
|
|
||
| Peer connections may fail, so decentralized training includes a retry mechanism. The maximum number of retries is controlled by `maxConnectionRetry`, which is specified in the task training information. | ||
|
|
||
| The retry mechanism is triggered when the server times out while waiting for `ConnectionsReady` messages. | ||
| The process is as follows: | ||
| 1. If the number of retries is still below `maxConnectionRetry`, the server sends a `RetryPeerConnection` message to all peers in the current round. | ||
| 2. When clients receive `RetryPeerConnection`, they clean up their peer pool and aggregator nodes, then rerun the peer connection phase. | ||
| 3. If the connection setup still fails after more than `maxConnectionRetry` attempts, the server removes the failed peers from the round peer list. | ||
| 4. The server sends a `ConnectionFail` message to the failed peers. | ||
| 5. When a client receives `ConnectionFail`, it disconnects from the server. | ||
| 6. The remaining clients receive `RetryPeerConnection` and retry the peer connection phase without the failed peers. | ||
|
|
||
| This allows the remaining participants to continue training even when one or more peers fail to establish connections. | ||
|
|
||
| ### Model Syncing for Participants Joining in the Middle of Training | ||
|
|
||
| Participants that join in the middle of training need to receive the latest model before they can participate in future rounds. In decentralized learning, peers do not send model weights to the server, so the newcomer must request the latest model from an existing peer. | ||
|
|
||
| The model syncing process is as follows: | ||
| 1. When a new participant joins in the middle of training, the server marks the participant as having joined mid-training in `NewDecentralizedNodeInfo`. | ||
| 2. After receiving `NewDecentralizedNodeInfo`, the new client sets a local flag indicating that it needs model syncing. | ||
| 3. When training begins, if this flag is set, the new client sends a `ModelSyncRequest` message to the server. | ||
| 4. After receiving `ModelSyncRequest`, the server sends messages as step 5 and 6, using selected model provider information from previous training round. | ||
| 5. The server sends `SignalModelProvider` to the new participant with information about the provider peer. | ||
| 6. The server sends `SignalNewPeer` to the provider peer with information about the newly joined peer. | ||
| 7. Using this signaling information, the new participant and provider peer establish a peer connection. | ||
| 8. The provider waits until the current aggregation round has finished, then sends the latest model to the new participant using a `SharedModel` message. | ||
| 9. The new participant receives the model and updates its local model weights. | ||
| 10. After syncing, the new participant can join subsequent decentralized training rounds. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,7 @@ import type { DataType, Model, WeightsContainer } from "../../index.js"; | |
| import { serialization } from "../../index.js"; | ||
| import { Client, shortenId } from '../client.js' | ||
| import { type NodeID } from '../index.js' | ||
| import { type, type ClientConnected } from '../messages.js' | ||
| import { type, type ClientConnected, NarrowMessage } from '../messages.js' | ||
| import { timeout } from '../utils.js' | ||
| import { WebSocketServer, waitMessage, type PeerConnection, waitMessageWithTimeout } from '../event_connection.js' | ||
| import { PeerPool } from './peer_pool.js' | ||
|
|
@@ -26,6 +26,14 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| #pool?: PeerPool | ||
| #connections?: Map<NodeID, PeerConnection> | ||
|
|
||
| // Flag if this model requires model synchronization | ||
| #modelSyncNeeded?: boolean | ||
|
|
||
| // Check if the training round is in progress | ||
| // Used to get the latest model for model synchronization | ||
| #roundFinishedPromise?: Promise<void> | ||
| #resolveRoundFinished?: () => void // contains resolver | ||
|
|
||
| // Used to handle timeouts and promise resolving after calling disconnect | ||
| private get isDisconnected() : boolean { | ||
| return this._server === undefined | ||
|
|
@@ -36,6 +44,24 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| // Emits the `participants` event | ||
| this.nbOfParticipants = this.aggregator.nodes.size === 0 ? 1 : this.aggregator.nodes.size | ||
| } | ||
|
|
||
| // Used by model provider peer during model syncing | ||
| private async handleSignalNewPeer(event: NarrowMessage<type.SignalNewPeer>): Promise<void> { | ||
| if (this.#pool === undefined) throw new Error('received signal about new peer but peer pool is undefined') | ||
| const roundFinishedPromise = this.#roundFinishedPromise | ||
| const syncConnection = await this.#pool.getPeers(Set([event.newNode]), this.server, ()=>{}) | ||
|
|
||
| const newcomerConn = syncConnection.get(event.newNode) | ||
|
|
||
| if (newcomerConn === undefined){ | ||
| // if connection with newly joining client fails, print debug message | ||
| // and return | ||
| debug(`Cannot connect to newly joined client [${event.newNode}]`) | ||
| return | ||
| } | ||
|
|
||
| await this.sendModel(newcomerConn, roundFinishedPromise) | ||
| } | ||
|
|
||
| /** | ||
| * Public method called by disco.ts when starting training. This method sends | ||
|
|
@@ -69,6 +95,13 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| this.#pool.signal(event.peer, event.signal) | ||
| }) | ||
|
|
||
| // Listen if the client is selected as a model provider node for a newly joining client. | ||
| // Upon receiving the signal, this client establishes a connection with the newcomer | ||
| // and sends the latest model weights. | ||
| this.server.on(type.SignalNewPeer, (event) => { | ||
| void this.handleSignalNewPeer(event) | ||
| }) | ||
|
|
||
| // c.f. setupServerCallbacks doc for explanation | ||
| let receivedEnoughParticipants = false | ||
| this.setupServerCallbacks(() => receivedEnoughParticipants = true) | ||
|
|
@@ -79,8 +112,9 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| this.server.send(msg) | ||
|
|
||
| const { id, waitForMoreParticipants, | ||
| nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) | ||
|
|
||
| nbOfParticipants, joinedMidTraining } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) | ||
|
|
||
| this.#modelSyncNeeded = joinedMidTraining | ||
| this.nbOfParticipants = nbOfParticipants | ||
|
|
||
|
|
||
|
|
@@ -129,14 +163,50 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| * When connected, one peer creates a promise for every other peer's weight update | ||
| * and waits for it to resolve. | ||
| * | ||
| * If a client joined the training after the first round, | ||
| * model syncing happens first to get the latest model. | ||
| */ | ||
| override async onRoundBeginCommunication(): Promise<void> { | ||
| if (this.#modelSyncNeeded) { | ||
| // 1. If model sync is needed, send server a request | ||
| this.server.send({ type: type.ModelSyncRequest }) | ||
|
|
||
| // 2. Get the provider information from the server | ||
| const providerInfo = await waitMessageWithTimeout(this.server, type.SignalModelProvider, 30_000, "Timeout while waiting for the latest model provider") | ||
|
|
||
| if (this.#pool === undefined) { | ||
| throw new Error('peer pool is undefined, make sure to call `client.connect()` first') | ||
| } | ||
|
|
||
| // 3. Connect with model provider client and get the latest model | ||
| const syncConnection = await this.#pool.getPeers( | ||
| Set([providerInfo.providerNode]), | ||
| this.server, | ||
| ()=>{} | ||
| ) | ||
|
Comment on lines
+184
to
+188
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Calling |
||
| const providerConn = syncConnection.get(providerInfo.providerNode) | ||
|
|
||
| if (providerConn === undefined){ | ||
| throw new Error("The latest model provider is not connected") | ||
| } | ||
|
|
||
| const latestModel = await this.receiveModel(providerConn) | ||
| this.modelWeightAccess?.setModelWeight(latestModel) | ||
|
|
||
| this.emit("modelSynced", this.modelWeightAccess?.getModelWeight()) | ||
| this.#modelSyncNeeded = false | ||
| } | ||
|
|
||
| // Notify the server we want to join the next round so that the server | ||
| // waits for us to be ready before sending the list of peers for the round | ||
| this.server.send({ type: type.JoinRound }) | ||
| // Store the promise for the current round's aggregation result. | ||
| // We will await for it to resolve at the end of the round when exchanging weight updates. | ||
| this.aggregationResult = this.aggregator.getPromiseForAggregation() | ||
|
|
||
| // Do not proceed to local training when minNbOfParticipants condition is not satisfied | ||
| await this.waitForParticipantsIfNeeded() | ||
|
|
||
| this.saveAndEmit("local training") | ||
| return Promise.resolve() | ||
| } | ||
|
|
@@ -149,11 +219,55 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| // Once enough new participants join we can display the previous status again | ||
| this.saveAndEmit("connecting to peers") | ||
| // First we check if we are waiting for more participants before sending our weight update | ||
| await this.waitForParticipantsIfNeeded() | ||
| // Create peer-to-peer connections with all peers for the round | ||
| await this.establishPeerConnections() | ||
|
|
||
| while(true){ | ||
| // Wait until enough participants are available before continuing the round | ||
| // Checks minNbOfParticipants requirement for | ||
| // when participants disconnect when connection error happens continuously | ||
| await this.waitForParticipantsIfNeeded() | ||
|
|
||
| // Create peer-to-peer connections with all peers for the round | ||
| await this.establishPeerConnections() | ||
|
|
||
| // Wait for connection related messages from the server before exchanging weight updates | ||
| // (1) If the client receives a StartWeightSharing message, it proceeds to weight update exchange | ||
| // (2) If it receives a RetryPeerConnections message, it retries peer connection establishment | ||
| // (3) After multiple retires, if the connection is still unsuccessful, the server starts excluding nodes from the round | ||
| // and sends a ConnectionFail message to those nodes | ||
| // (4) Upon receiving ConnectionFail, the client disconnects from the server | ||
| const msg = await Promise.race([ | ||
| waitMessage(this.server, type.StartWeightSharing), | ||
| waitMessage(this.server, type.RetryPeerConnections), | ||
| waitMessage(this.server, type.ConnectionFail), | ||
| ]) | ||
|
Comment on lines
+240
to
+244
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This creates two event listeners that are not resolved per round and accumulates throughout the rounds. I don't think it's a huge deal but can you add a comment/TODO to note this? |
||
|
|
||
| if (msg.type === type.StartWeightSharing){ | ||
| // Generate a promise that resolves when round training finishes | ||
| if (this.#roundFinishedPromise === undefined){ | ||
| this.#roundFinishedPromise = new Promise<void>((resolve) => { | ||
| this.#resolveRoundFinished = resolve | ||
| }) | ||
| } | ||
| break | ||
| } else if (msg.type === type.RetryPeerConnections){ | ||
| debug(`[${shortenId(this.ownId)}] retrying peer connection establishment`) | ||
| // clear the communication round peer pool | ||
| await this.#pool?.shutdown() | ||
| this.#pool = new PeerPool(this.ownId) | ||
| // clear the connections | ||
| this.#connections = Map() | ||
| this.setAggregatorNodes(Set(this.ownId)) | ||
| continue | ||
| } else if (msg.type === type.ConnectionFail){ | ||
| debug(`[${shortenId(this.ownId)}] disconnect from the server`) | ||
| await this.disconnect() | ||
| throw new Error("Client disconnected after connection failure") | ||
| } | ||
| } | ||
| // Exchange weight updates with peers and return aggregated weights | ||
| return await this.exchangeWeightUpdates(weights) | ||
| const aggregatedWeight = await this.exchangeWeightUpdates(weights) | ||
|
Comment on lines
+254
to
+270
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this scenario ever covered in the unit tests? |
||
|
|
||
| return aggregatedWeight | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -178,8 +292,9 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| try { | ||
| debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`); | ||
| const receivedMessage = await waitMessage(this.server, type.PeersForRound) | ||
|
|
||
| const peers = Set(receivedMessage.peers) | ||
| debug(`[${shortenId(this.ownId)}] received peer list: %o`, peers.toArray()); | ||
|
|
||
| if (this.ownId !== undefined && peers.has(this.ownId)) { | ||
| throw new Error('received peer list contains our own id') | ||
|
|
@@ -198,7 +313,9 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| (conn) => this.receivePayloads(conn) | ||
| ) | ||
|
|
||
| debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS()); | ||
| // Signal server that all connections with other peers in the round are established | ||
| this.server.send({ type: type.ConnectionsReady }); | ||
| debug(`[${shortenId(this.ownId)}] peer connections ready: %o`, connections.keySeq().toJS()); | ||
| this.#connections = connections | ||
| } catch (e) { | ||
| debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e); | ||
|
|
@@ -303,4 +420,48 @@ export class DecentralizedClient extends Client<"decentralized"> { | |
| } | ||
| return await this.aggregationResult | ||
| } | ||
|
|
||
| /** | ||
| * Receive model from the model provider. | ||
| */ | ||
| private async receiveModel(providerConn: PeerConnection): Promise<WeightsContainer>{ | ||
| const message = await waitMessageWithTimeout(providerConn, type.SharedModel, 30_000, "Timeout while waiting for the latest model") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should make this timeout value a task parameter, depending on the model size it can make sense to increase the timeout |
||
|
|
||
| const decoded = serialization.weights.decode(message.model) | ||
| return decoded | ||
| } | ||
|
|
||
| /** | ||
| * Send the latest available model to a newly joining client. | ||
| * If the current training round is in progress, wait until the round finishes | ||
| * and receive the latest aggregated model. | ||
| */ | ||
| private async sendModel(newcomerConn: PeerConnection, roundFinishedPromise: Promise<void> | undefined): Promise<void> { | ||
| // wait until the round finishes to get the latest model | ||
| if (roundFinishedPromise !== undefined){ | ||
| await roundFinishedPromise | ||
| } | ||
|
|
||
| const model = this.modelWeightAccess?.getModelWeight() | ||
|
|
||
| if (model === undefined){ | ||
| debug("Failed to get the latest model from model provider client") | ||
| return | ||
| } | ||
| const encoded = await serialization.weights.encode(model) | ||
|
|
||
| const message: messages.SharedModel = { | ||
| type: type.SharedModel, | ||
| model: encoded | ||
| } | ||
| newcomerConn.send(message) | ||
| } | ||
|
|
||
| // Resolve the round finished promise and reset related state | ||
| override finishRound(): void{ | ||
| // Mark round as finished so that model synchronization can proceed | ||
| this.#resolveRoundFinished?.() | ||
| this.#roundFinishedPromise = undefined | ||
| this.#resolveRoundFinished = undefined | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.