Skip to content

Commit

Permalink
ability for extension protocols to be able to respond viaa PWC, and a…
Browse files Browse the repository at this point in the history
…bility to announce and detect the presence of DHT support
  • Loading branch information
bizzehdee committed Apr 13, 2014
1 parent 39b0406 commit 0aed363
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 52 deletions.
37 changes: 33 additions & 4 deletions Extensions/UTPeerExchange.cs
Expand Up @@ -103,9 +103,38 @@ public void OnExtendedMessage(PeerWireClient peerWireClient, byte[] bytes)
}
}

public void SendMessage(PeerWireClient peerWireClient, IPEndPoint[] addedEndPoints, byte[] flags)
{

}
public void SendMessage(PeerWireClient peerWireClient, IPEndPoint[] addedEndPoints, byte[] flags, IPEndPoint[] droppedEndPoints)
{
if (addedEndPoints == null && droppedEndPoints == null) return;

BDict d = new BDict();

if (addedEndPoints != null)
{
byte[] added = new byte[addedEndPoints.Length * 6];
for (int x = 0; x < addedEndPoints.Length; x++)
{
addedEndPoints[x].Address.GetAddressBytes().CopyTo(added, x * 6);
BitConverter.GetBytes((ushort)addedEndPoints[x].Port).CopyTo(added, (x * 6)+4);
}

d.Add("added", new BString { ByteValue = added });
}

if (droppedEndPoints != null)
{
byte[] dropped = new byte[droppedEndPoints.Length * 6];
for (int x = 0; x < droppedEndPoints.Length; x++)
{
droppedEndPoints[x].Address.GetAddressBytes().CopyTo(dropped, x * 6);

dropped.SetValue((ushort)droppedEndPoints[x].Port, (x * 6) + 2);
}

d.Add("dropped", new BString { ByteValue = dropped });
}

peerWireClient.SendExtended(peerWireClient.GetOutgoingMessageID(this), BencodingUtils.EncodeBytes(d));
}
}
}
155 changes: 107 additions & 48 deletions PeerWireClient.cs
Expand Up @@ -46,16 +46,18 @@ public class PeerWireClient
internal readonly Socket Socket;
private byte[] _internalBuffer; //async internal buffer
private readonly List<IBTExtension> _protocolExtensions;
private readonly Dictionary<String, Int64> _extOutgoing = new Dictionary<string, long>();
private readonly Dictionary<Int64, String> _extIncoming = new Dictionary<Int64, String>();
private readonly Dictionary<String, byte> _extOutgoing = new Dictionary<string, byte>();
private readonly Dictionary<byte, String> _extIncoming = new Dictionary<byte, String>();

public Int32 Timeout { get; private set; }
public bool[] PeerBitField { get; set; }
public bool KeepConnectionAlive { get; set; }
public bool UseExtended { get; set; }
public bool UseFast { get; set; }
public bool UseDHT { get; set; }
public bool RemoteUsesExtended { get; private set; }
public bool RemoteUsesFast { get; private set; }
public bool UseFast { get; set; }
public bool RemoteUsesDHT { get; private set; }
public String LocalPeerID { get; set; }
public String RemotePeerID { get; private set; }
public String Hash { get; set; }
Expand Down Expand Up @@ -83,11 +85,9 @@ public PeerWireClient(Int32 timeout)

Timeout = timeout;

Socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)
{
ReceiveTimeout = timeout*1000,
SendTimeout = timeout*1000
};
Socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
Socket.ReceiveTimeout = timeout*1000;
Socket.SendTimeout = timeout*1000;

_internalBuffer = new byte[0];
}
Expand Down Expand Up @@ -118,7 +118,7 @@ public void Connect(String ipHost, Int32 port)
public void Disconnect()
{
Socket.Disconnect(false);
Socket.Close();
//Socket.Close();
}

public void Handshake()
Expand All @@ -143,16 +143,17 @@ public void Handshake(byte[] hash, byte[] peerId)
if (peerId.Length != 20) throw new ArgumentOutOfRangeException("peerId", "Peer ID must be 20 bytes exactly");

byte[] reservedBytes = {0, 0, 0, 0, 0, 0, 0, 0};
if(UseExtended) reservedBytes[5] |= 0x10;
if(UseFast) reservedBytes[7] |= 0x04;
reservedBytes[5] |= (byte)(UseExtended ? 0x10 : 0x00);
reservedBytes[7] |= (byte)(UseFast ? 0x04 : 0x00);
reservedBytes[7] |= (byte)(UseDHT ? 0x1 : 0x00);

byte[] sendBuf = (new[] { (byte)_bitTorrentProtocolHeader.Length }).Concat(_bitTorrentProtocolHeader).Concat(reservedBytes).Concat(hash).Concat(peerId).ToArray();

if (UseExtended)
{
BDict handshakeDict = new BDict();
BDict mDict = new BDict();
Int32 i = 1;
byte i = 1;
foreach (IBTExtension extension in _protocolExtensions)
{
_extOutgoing.Add(extension.Protocol, i);
Expand Down Expand Up @@ -196,14 +197,20 @@ public void Handshake(byte[] hash, byte[] peerId)
Int32 resLen = readBuf[0];
if (resLen != 19)
{
Socket.Disconnect(false);
Socket.Close();
throw new InvalidProgramException("Invalid response received from peer");
if (resLen == 0)
{
// keep alive?
Thread.Sleep(100);

Disconnect();
return;
}
}

byte[] recReserved = readBuf.Skip(20).Take(8).ToArray();
RemoteUsesExtended = (recReserved[5] & 0x10) == 0x10;
RemoteUsesFast = (recReserved[7] & 0x04) == 0x04;
RemoteUsesDHT = (recReserved[7] & 0x1) == 0x1;

byte[] recBuffer = new byte[128];
Socket.BeginReceive(recBuffer, 0, 128, SocketFlags.None, OnReceived, recBuffer);
Expand Down Expand Up @@ -274,7 +281,10 @@ public void SendBitField(bool[] bitField, bool obsf)
int x = (int)Math.Floor((double)i/8);
ushort p = (ushort) (i%8);

if(bitField[i]) bytes[x] = bytes[x].SetBit(p);
if (bitField[i])
{
bytes[x] = bytes[x].SetBit(p);
}
}

Socket.Send(Pack.Int32(1 + bitField.Length, Pack.Endianness.Big).Concat(new byte[] { 5 }).Concat(bytes).ToArray());
Expand Down Expand Up @@ -303,9 +313,11 @@ public void SendCancel(Int32 index, Int32 start, Int32 length)
Socket.Send(Pack.Int32(13, Pack.Endianness.Big).Concat(new byte[] { 8 }).Concat(Pack.Int32(index)).Concat(Pack.Int32(start)).Concat(Pack.Int32(length)).ToArray());
}

public void SendExtended(Int32 extMsgId, Int32 start, Int32 length)
public void SendExtended(byte extMsgId, byte[] bytes)
{

Int32 length = 2 + bytes.Length;

Socket.Send(Pack.Int32(length, Pack.Endianness.Big).Concat(new [] { (byte)20} ).Concat(new [] { extMsgId }).Concat(bytes).ToArray());
}

public void OnReceived(IAsyncResult ar)
Expand All @@ -322,21 +334,17 @@ public void OnReceived(IAsyncResult ar)
}

byte[] recBuffer = new byte[128];
if (Socket.Connected) Socket.BeginReceive(recBuffer, 0, 128, SocketFlags.None, OnReceived, recBuffer);

if (Socket.Connected)
{
Socket.BeginReceive(recBuffer, 0, 128, SocketFlags.None, OnReceived, recBuffer);
}
}

public bool Process()
{
Thread.Sleep(10);

/*if (_socket.Connected && _socket.Available > 0)
{
byte[] recBuffer = new byte[_socket.Available];
_socket.Receive(recBuffer);
_internalBuffer = _internalBuffer == null ? recBuffer : _internalBuffer.Concat(recBuffer).ToArray();
}*/

if (_internalBuffer.Length < 4)
{
if (!Socket.Connected) return false;
Expand Down Expand Up @@ -560,10 +568,16 @@ private void ProcessReject()
private void ProcessExtended(Int32 length)
{
Int32 msgId = _internalBuffer[0];
lock (_locker) _internalBuffer = _internalBuffer.Skip(1).ToArray();
lock (_locker)
{
_internalBuffer = _internalBuffer.Skip(1).ToArray();
}

byte[] buffer = _internalBuffer.Take(length-1).ToArray();
lock (_locker) _internalBuffer = _internalBuffer.Skip(length - 1).ToArray();
lock (_locker)
{
_internalBuffer = _internalBuffer.Skip(length - 1).ToArray();
}

if (msgId == 0)
{
Expand All @@ -573,7 +587,7 @@ private void ProcessExtended(Int32 length)
foreach (KeyValuePair<string, IBencodingType> pair in mDict)
{
BInt i = (BInt)pair.Value;
_extIncoming.Add(i, pair.Key);
_extIncoming.Add((byte)i, pair.Key);

IBTExtension ext = _protocolExtensions.FirstOrDefault(f => f.Protocol == pair.Key);

Expand All @@ -585,7 +599,7 @@ private void ProcessExtended(Int32 length)
}
else
{
KeyValuePair<Int64, String> pair = _extIncoming.FirstOrDefault(f => f.Key == msgId);
KeyValuePair<byte, String> pair = _extIncoming.FirstOrDefault(f => f.Key == msgId);
IBTExtension ext = _protocolExtensions.FirstOrDefault(f => f.Protocol == pair.Value);

if (ext != null)
Expand All @@ -599,7 +613,10 @@ private void ProcessAllowFast()
{
Int32 index = Unpack.Int32(_internalBuffer, 0, Unpack.Endianness.Big);

lock (_locker) _internalBuffer = _internalBuffer.Skip(4).ToArray();
lock (_locker)
{
_internalBuffer = _internalBuffer.Skip(4).ToArray();
}

OnAllowFast(index);
}
Expand All @@ -610,52 +627,82 @@ private void ProcessAllowFast()

private void OnKeepAlive()
{
if (KeepAlive != null) KeepAlive(this);
if (KeepAlive != null)
{
KeepAlive(this);
}
}

private void OnChoke()
{
if (Choke != null) Choke(this);
if (Choke != null)
{
Choke(this);
}
}

private void OnUnChoke()
{
if (UnChoke != null) UnChoke(this);
if (UnChoke != null)
{
UnChoke(this);
}
}

private void OnInterested()
{
if (Interested != null) Interested(this);
if (Interested != null)
{
Interested(this);
}
}

private void OnNotInterested()
{
if (NotInterested != null) NotInterested(this);
if (NotInterested != null)
{
NotInterested(this);
}
}

private void OnHave(Int32 pieceIndex)
{
if (Have != null) Have(this, pieceIndex);
if (Have != null)
{
Have(this, pieceIndex);
}
}

private void OnBitField(Int32 size, bool[] bitField)
{
if (BitField != null) BitField(this, size, bitField);
if (BitField != null)
{
BitField(this, size, bitField);
}
}

private void OnRequest(Int32 index, Int32 begin, Int32 length)
{
if (Request != null) Request(this, index, begin, length);
if (Request != null)
{
Request(this, index, begin, length);
}
}

private void OnPiece(Int32 index, Int32 begin, byte[] bytes)
{
if (Piece != null) Piece(this, index, begin, bytes);
if (Piece != null)
{
Piece(this, index, begin, bytes);
}
}

private void OnCancel(Int32 index, Int32 begin, Int32 length)
{
if (Cancel != null) Cancel(this, index, begin, length);
if (Cancel != null)
{
Cancel(this, index, begin, length);
}
}

private void OnPort(UInt16 port)
Expand All @@ -676,22 +723,34 @@ private void OnSuggest(Int32 pieceIndex)

private void OnHaveAll()
{
if (HaveAll != null) HaveAll(this);
if (HaveAll != null)
{
HaveAll(this);
}
}

private void OnHaveNone()
{
if (HaveNone != null) HaveNone(this);
if (HaveNone != null)
{
HaveNone(this);
}
}

private void OnReject(Int32 index, Int32 begin, Int32 length)
{
if (Reject != null) Reject(this, index, begin, length);
if (Reject != null)
{
Reject(this, index, begin, length);
}
}

private void OnAllowFast(Int32 pieceIndex)
{
if (AllowedFast != null) AllowedFast(this, pieceIndex);
if (AllowedFast != null)
{
AllowedFast(this, pieceIndex);
}
}
#endregion

Expand All @@ -707,14 +766,14 @@ public void UnregisterProtocolExtension(IBTExtension extension)
extension.Deinit(this);
}

public Int64 GetOutgoingMessageID(IBTExtension extension)
public byte GetOutgoingMessageID(IBTExtension extension)
{
if (_extOutgoing.ContainsKey(extension.Protocol))
{
return _extOutgoing[extension.Protocol];
}

return -1;
return 0;
}
}
}

0 comments on commit 0aed363

Please sign in to comment.