+183
-10
lines changedFilter options
+183
-10
lines changed Original file line number Diff line number Diff line change
@@ -146,6 +146,61 @@ static bool PyUpb_DescriptorPool_TryLoadFilename(PyUpb_DescriptorPool* self,
146
146
return ret;
147
147
}
148
148
149
+
static bool PyUpb_DescriptorPool_TryLoadExtension(PyUpb_DescriptorPool* self,
150
+
const upb_MessageDef* m,
151
+
int32_t field_number) {
152
+
if (!self->db) return false;
153
+
const char* full_name = upb_MessageDef_FullName(m);
154
+
PyObject* py_name = PyUnicode_FromStringAndSize(full_name, strlen(full_name));
155
+
PyObject* py_descriptor = PyObject_CallMethod(
156
+
self->db, "FindFileContainingExtension", "Oi", py_name, field_number);
157
+
Py_DECREF(py_name);
158
+
if (!py_descriptor) {
159
+
PyErr_Clear();
160
+
return false;
161
+
}
162
+
bool ret = PyUpb_DescriptorPool_TryLoadFileProto(self, py_descriptor);
163
+
Py_DECREF(py_descriptor);
164
+
return ret;
165
+
}
166
+
167
+
static void PyUpb_DescriptorPool_TryLoadAllExtensions(
168
+
PyUpb_DescriptorPool* self, const upb_MessageDef* m) {
169
+
if (!self->db) return;
170
+
const char* full_name = upb_MessageDef_FullName(m);
171
+
PyObject* py_name = PyUnicode_FromStringAndSize(full_name, strlen(full_name));
172
+
PyObject* py_list =
173
+
PyObject_CallMethod(self->db, "FindAllExtensionNumbers", "O", py_name);
174
+
Py_DECREF(py_name);
175
+
if (!py_list) {
176
+
PyErr_Clear();
177
+
return;
178
+
}
179
+
Py_ssize_t size = PyList_Size(py_list);
180
+
if (size == -1) {
181
+
PyErr_Format(
182
+
PyExc_RuntimeError,
183
+
"FindAllExtensionNumbers() on fall back DB must return a list, not %S",
184
+
py_list);
185
+
PyErr_Print();
186
+
Py_DECREF(py_list);
187
+
return;
188
+
}
189
+
int64_t field_number;
190
+
const upb_ExtensionRegistry* reg =
191
+
upb_DefPool_ExtensionRegistry(self->symtab);
192
+
const upb_MiniTable* t = upb_MessageDef_MiniTable(m);
193
+
for (Py_ssize_t i = 0; i < size; ++i) {
194
+
PyObject* item = PySequence_GetItem(py_list, i);
195
+
field_number = PyLong_AsLong(item);
196
+
Py_DECREF(item);
197
+
if (!upb_ExtensionRegistry_Lookup(reg, t, field_number)) {
198
+
PyUpb_DescriptorPool_TryLoadExtension(self, m, field_number);
199
+
}
200
+
}
201
+
Py_DECREF(py_list);
202
+
}
203
+
149
204
bool PyUpb_DescriptorPool_CheckNoDatabase(PyObject* _self) { return true; }
150
205
151
206
static bool PyUpb_DescriptorPool_LoadDependentFiles(
@@ -582,8 +637,15 @@ static PyObject* PyUpb_DescriptorPool_FindExtensionByNumber(PyObject* _self,
582
637
return NULL;
583
638
}
584
639
585
-
const upb_FieldDef* f = upb_DefPool_FindExtensionByNumber(
586
-
self->symtab, PyUpb_Descriptor_GetDef(message_descriptor), number);
640
+
const upb_MessageDef* message_def =
641
+
PyUpb_Descriptor_GetDef(message_descriptor);
642
+
const upb_FieldDef* f =
643
+
upb_DefPool_FindExtensionByNumber(self->symtab, message_def, number);
644
+
if (f == NULL && self->db) {
645
+
if (PyUpb_DescriptorPool_TryLoadExtension(self, message_def, number)) {
646
+
f = upb_DefPool_FindExtensionByNumber(self->symtab, message_def, number);
647
+
}
648
+
}
587
649
if (f == NULL) {
588
650
return PyErr_Format(PyExc_KeyError, "Couldn't find Extension %d", number);
589
651
}
@@ -595,6 +657,9 @@ static PyObject* PyUpb_DescriptorPool_FindAllExtensions(PyObject* _self,
595
657
PyObject* msg_desc) {
596
658
PyUpb_DescriptorPool* self = (PyUpb_DescriptorPool*)_self;
597
659
const upb_MessageDef* m = PyUpb_Descriptor_GetDef(msg_desc);
660
+
if (self->db) {
661
+
PyUpb_DescriptorPool_TryLoadAllExtensions(self, m);
662
+
}
598
663
size_t n;
599
664
const upb_FieldDef** ext = upb_DefPool_GetAllExtensions(self->symtab, m, &n);
600
665
PyObject* ret = PyList_New(n);
Original file line number Diff line number Diff line change
@@ -580,11 +580,21 @@ def FindAllExtensions(self, message_descriptor):
580
580
if self._descriptor_db and hasattr(
581
581
self._descriptor_db, 'FindAllExtensionNumbers'):
582
582
full_name = message_descriptor.full_name
583
-
all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
584
-
for number in all_numbers:
585
-
if number in self._extensions_by_number[message_descriptor]:
586
-
continue
587
-
self._TryLoadExtensionFromDB(message_descriptor, number)
583
+
try:
584
+
all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
585
+
except:
586
+
pass
587
+
else:
588
+
if isinstance(all_numbers, list):
589
+
for number in all_numbers:
590
+
if number in self._extensions_by_number[message_descriptor]:
591
+
continue
592
+
self._TryLoadExtensionFromDB(message_descriptor, number)
593
+
else:
594
+
warnings.warn(
595
+
'FindAllExtensionNumbers() on fall back DB must return a list,'
596
+
' not {0}'.format(type(all_numbers))
597
+
)
588
598
589
599
return list(self._extensions_by_number[message_descriptor].values())
590
600
@@ -603,8 +613,13 @@ def _TryLoadExtensionFromDB(self, message_descriptor, number):
603
613
return
604
614
605
615
full_name = message_descriptor.full_name
606
-
file_proto = self._descriptor_db.FindFileContainingExtension(
607
-
full_name, number)
616
+
file_proto = None
617
+
try:
618
+
file_proto = self._descriptor_db.FindFileContainingExtension(
619
+
full_name, number
620
+
)
621
+
except:
622
+
return
608
623
609
624
if file_proto is None:
610
625
return
@@ -668,7 +683,6 @@ def SetFeatureSetDefaults(self, defaults):
668
683
if not isinstance(defaults, descriptor_pb2.FeatureSetDefaults):
669
684
raise TypeError('SetFeatureSetDefaults called with invalid type')
670
685
671
-
672
686
if defaults.minimum_edition > defaults.maximum_edition:
673
687
raise ValueError(
674
688
'Invalid edition range %s to %s'
Original file line number Diff line number Diff line change
@@ -1531,5 +1531,91 @@ def testAboveMaximum(self):
1531
1531
'google/protobuf/internal/more_messages.proto'])
1532
1532
1533
1533
1534
+
class LocalFakeDB(descriptor_database.DescriptorDatabase):
1535
+
1536
+
def FindFileContainingExtension(self, extendee_name, extension_number):
1537
+
return descriptor_pb2.FileDescriptorProto.FromString(
1538
+
factory_test2_pb2.DESCRIPTOR.serialized_pb
1539
+
)
1540
+
1541
+
def FindAllExtensionNumbers(self, extendee_name):
1542
+
return [1001, 1002]
1543
+
1544
+
1545
+
class BadDB(descriptor_database.DescriptorDatabase):
1546
+
1547
+
def FindFileContainingExtension(self, extendee_name, extension_number):
1548
+
raise RuntimeError('just ignore it')
1549
+
1550
+
def FindAllExtensionNumbers(self, extendee_name):
1551
+
raise RuntimeError('just ignore it')
1552
+
1553
+
1554
+
class BadDB2(descriptor_database.DescriptorDatabase):
1555
+
1556
+
# Returns a none list value.
1557
+
def FindAllExtensionNumbers(self, extendee_name):
1558
+
return 1.2
1559
+
1560
+
1561
+
@testing_refleaks.TestCase
1562
+
class FallBackDBTest(unittest.TestCase):
1563
+
1564
+
def setUp(self):
1565
+
self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
1566
+
factory_test1_pb2.DESCRIPTOR.serialized_pb
1567
+
)
1568
+
factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
1569
+
factory_test2_pb2.DESCRIPTOR.serialized_pb
1570
+
)
1571
+
db = LocalFakeDB()
1572
+
db.Add(self.factory_test1_fd)
1573
+
db.Add(factory_test2_fd)
1574
+
self.pool = descriptor_pool.DescriptorPool(db)
1575
+
file_desc = self.pool.FindFileByName(
1576
+
'google/protobuf/internal/factory_test1.proto'
1577
+
)
1578
+
self.message_desc = file_desc.message_types_by_name['Factory1Message']
1579
+
1580
+
bad_db = BadDB()
1581
+
bad_db.Add(self.factory_test1_fd)
1582
+
self.bad_pool = descriptor_pool.DescriptorPool(bad_db)
1583
+
1584
+
def testFindExtensionByNumber(self):
1585
+
ext = self.pool.FindExtensionByNumber(self.message_desc, 1001)
1586
+
self.assertEqual(ext.name, 'one_more_field')
1587
+
1588
+
def testFindAllExtensions(self):
1589
+
extensions = self.pool.FindAllExtensions(self.message_desc)
1590
+
self.assertEqual(len(extensions), 2)
1591
+
1592
+
def testIgnoreBadFindExtensionByNumber(self):
1593
+
file_desc = self.bad_pool.FindFileByName(
1594
+
'google/protobuf/internal/factory_test1.proto'
1595
+
)
1596
+
message_desc = file_desc.message_types_by_name['Factory1Message']
1597
+
with self.assertRaises(KeyError):
1598
+
ext = self.bad_pool.FindExtensionByNumber(message_desc, 1001)
1599
+
1600
+
def testIgnoreBadFindAllExtensions(self):
1601
+
file_desc = self.bad_pool.FindFileByName(
1602
+
'google/protobuf/internal/factory_test1.proto'
1603
+
)
1604
+
message_desc = file_desc.message_types_by_name['Factory1Message']
1605
+
extensions = self.bad_pool.FindAllExtensions(message_desc)
1606
+
self.assertEqual(len(extensions), 0)
1607
+
1608
+
def testFindAllExtensionsReturnsNoneList(self):
1609
+
db = BadDB2()
1610
+
db.Add(self.factory_test1_fd)
1611
+
pool = descriptor_pool.DescriptorPool(db)
1612
+
file_desc = pool.FindFileByName(
1613
+
'google/protobuf/internal/factory_test1.proto'
1614
+
)
1615
+
message_desc = file_desc.message_types_by_name['Factory1Message']
1616
+
extensions = self.bad_pool.FindAllExtensions(message_desc)
1617
+
self.assertEqual(len(extensions), 0)
1618
+
1619
+
1534
1620
if __name__ == '__main__':
1535
1621
unittest.main()
Original file line number Diff line number Diff line change
@@ -146,6 +146,14 @@ bool PyDescriptorDatabase::FindAllExtensionNumbers(
146
146
return false;
147
147
}
148
148
Py_ssize_t size = PyList_Size(py_list.get());
149
+
if (size == -1) {
150
+
PyErr_Format(
151
+
PyExc_RuntimeError,
152
+
"FindAllExtensionNumbers() on fall back DB must return a list, not %S",
153
+
py_list.get());
154
+
PyErr_Print();
155
+
return false;
156
+
}
149
157
int64_t item_value;
150
158
for (Py_ssize_t i = 0 ; i < size; ++i) {
151
159
ScopedPyObjectPtr item(PySequence_GetItem(py_list.get(), i));
You can’t perform that action at this time.
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4