diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3f6600e..881b1fd 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,12 @@ Release History dev --- +**API Changes (Backward Compatible)** + +- GoAwayFrame and WindowUpdateFrame now correctly mask off the reserved bit during + parsing and serialization of stream IDs and window increments, as per RFC 9113, + Sections 6.8 and 6.9. + **API Changes (Backward Incompatible)** - diff --git a/src/hyperframe/frame.py b/src/hyperframe/frame.py index a67487e..2807d68 100644 --- a/src/hyperframe/frame.py +++ b/src/hyperframe/frame.py @@ -132,7 +132,7 @@ def parse_frame_header(header: memoryview, strict: bool = False) -> tuple[Frame, length = (fields[0] << 8) + fields[1] typ_e = fields[2] flags = fields[3] - stream_id = fields[4] & 0x7FFFFFFF + stream_id = fields[4] & 0x7FFFFFFF # mask off the reserved bit, RFC 9113, Section 4.1 try: frame = FRAMES[typ_e](stream_id) @@ -172,7 +172,7 @@ def serialize(self) -> bytes: self.body_len & 0xFF, self.type, flags, - self.stream_id & 0x7FFFFFFF, # Stream ID is 32 bits. + self.stream_id & 0x7FFFFFFF, # mask off the reserved bit, RFC 9113, Section 4.1 ) return header + body @@ -271,7 +271,7 @@ def parse_priority_data(self, data: memoryview) -> int: raise InvalidFrameError(msg) from err self.exclusive = bool(self.depends_on >> 31) - self.depends_on &= 0x7FFFFFFF + self.depends_on &= 0x7FFFFFFF # mask off the exclusive bit, RFC 9113, Section 6.3 return 5 @@ -620,7 +620,7 @@ def _body_repr(self) -> str: def serialize_body(self) -> bytes: data = _STRUCT_LL.pack( - self.last_stream_id & 0x7FFFFFFF, + self.last_stream_id & 0x7FFFFFFF, # mask off the reserved bit, RFC 9113, Section 6.8 self.error_code, ) data += self.additional_data @@ -636,6 +636,8 @@ def parse_body(self, data: memoryview) -> None: msg = "Invalid GOAWAY body." raise InvalidFrameError(msg) from err + # mask off the reserved bit, RFC 9113, Section 6.8 + self.last_stream_id = self.last_stream_id & 0x7FFFFFFF self.body_len = len(data) if len(data) > 8: @@ -674,7 +676,9 @@ def _body_repr(self) -> str: return f"window_increment={self.window_increment}" def serialize_body(self) -> bytes: - return _STRUCT_L.pack(self.window_increment & 0x7FFFFFFF) + return _STRUCT_L.pack( + self.window_increment & 0x7FFFFFFF, # mask off the reserved bit, RFC 9113, Section 6.9 + ) def parse_body(self, data: memoryview) -> None: if len(data) > 4: @@ -687,6 +691,9 @@ def parse_body(self, data: memoryview) -> None: msg = "Invalid WINDOW_UPDATE body" raise InvalidFrameError(msg) from err + # mask off the reserved bit, RFC 9113, Section 6.9 + self.window_increment = self.window_increment & 0x7FFFFFFF + if not 1 <= self.window_increment <= 2**31-1: msg = "WINDOW_UPDATE increment must be between 1 to 2^31-1" raise InvalidDataError(msg) @@ -904,7 +911,7 @@ def serialize(self) -> bytes: self.body_len & 0xFF, self.type, flags, - self.stream_id & 0x7FFFFFFF, # Stream ID is 32 bits. + self.stream_id & 0x7FFFFFFF, # mask off the reserved bit, RFC 9113, Section 4.1 ) return header + self.body diff --git a/tests/test_frames.py b/tests/test_frames.py index 10f2300..87f66d5 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -647,6 +647,35 @@ def test_short_goaway_frame_errors(self): with pytest.raises(InvalidFrameError): decode_frame(s) + def test_goaway_frame_with_reserved_bit_set_parses_properly(self): + s = ( + b'\x00\x00\x0D\x07\x00\x00\x00\x00\x00' + # Frame header + b'\x80\x00\x00\x40' + # Last Stream ID with reserved bit set + b'\x00\x00\x00\x20' + # Error Code + b'hello' # Additional data + ) + f = decode_frame(s) + + assert isinstance(f, GoAwayFrame) + assert f.flags == set() + assert f.additional_data == b'hello' + assert f.body_len == 13 + assert f.last_stream_id == 64 + + def test_goaway_frame_with_reserved_bit_set_serializes_properly(self): + f = GoAwayFrame() + f.last_stream_id = 64 + f.error_code = 32 + f.additional_data = b'hello' + + s = f.serialize() + assert s == ( + b'\x00\x00\x0D\x07\x00\x00\x00\x00\x00' + # Frame header + b'\x00\x00\x00\x40' + # Last Stream ID + b'\x00\x00\x00\x20' + # Error Code + b'hello' # Additional data + ) + class TestWindowUpdateFrame: def test_repr(self): @@ -694,6 +723,22 @@ def test_short_windowupdate_frame_errors(self): with pytest.raises(InvalidDataError): decode_frame(WindowUpdateFrame(2**31).serialize()) + def test_window_update_frame_with_reserved_bit_set_parses_properly(self): + s = b'\x00\x00\x04\x08\x00\x00\x00\x00\x80\x00\x00\x02\x00' + f = decode_frame(s) + + assert isinstance(f, WindowUpdateFrame) + assert f.flags == set() + assert f.window_increment == 512 + assert f.body_len == 4 + + def test_window_update_frame_with_reserved_bit_set_serializes_properly(self): + f = WindowUpdateFrame(0) + f.window_increment = 512 + + s = f.serialize() + assert s == b'\x00\x00\x04\x08\x00\x00\x00\x00\x00\x00\x00\x02\x00' + class TestHeadersFrame: def test_repr(self):