A RetroSearch Logo

Home - News ( United States | United Kingdom | Italy | Germany ) - Football scores

Search Query:

Showing content from https://github.com/tensorflow/tensorflow/commit/207d50d253e11c3a3430a700af478a1d524a779a below:

Add a static limit to folding binary ops to avoid constant size explo… · tensorflow/tensorflow@207d50d · GitHub

@@ -94,8 +94,18 @@ namespace {

94 94

// expand a single element splat to a multi-GB large tensor.

95 95

// The limit is arbitrary set low to allow expanding small computations, like

96 96

// shape manipulations for example.

97 +

// TODO(b/210478841): Define a constant folding policy that generalizes this.

97 98

constexpr int64_t kFoldExpandSplatEltLimit = 16;

98 99 100 +

// Similarly to the constant above, this is an arbitrary limit into how many

101 +

// elements can be folded by a binary operation folder.

102 +

// This limit doesn't apply to the following special cases:

103 +

// 1) Adding a zero.

104 +

// 2) Multiplying by one.

105 +

// 3) When both operands are splats.

106 +

// TODO(b/210478841): Define a constant folding policy that generalizes this.

107 +

constexpr int64_t kFoldBinaryOpEltLimit = 65536;

108 + 99 109

// Clamps value to the range [lower, upper]. Requires lower <= upper.

100 110

template <typename T>

101 111

static T Clamp(const T& value, const T& lower, const T& upper) {

@@ -5234,6 +5244,22 @@ static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {

5234 5244

return {};

5235 5245

}

5236 5246 5247 +

// Special case for folding splats no matter how large.

5248 +

// Only covers the case of both attrs being splats; operation-specific cases

5249 +

// like adding a zero or multiplying by one are handled elsewhere.

5250 +

SplatElementsAttr splat_lhs = lhs.dyn_cast<SplatElementsAttr>();

5251 +

SplatElementsAttr splat_rhs = rhs.dyn_cast<SplatElementsAttr>();

5252 +

if (splat_lhs && splat_rhs) {

5253 +

return SplatElementsAttr::get(

5254 +

type, Convert()(splat_lhs.getSplatValue<ValType>(),

5255 +

splat_rhs.getSplatValue<ValType>()));

5256 +

}

5257 + 5258 +

// Prevent folding if lhs/rhs are too large.

5259 +

if (lhs.getNumElements() > kFoldBinaryOpEltLimit) {

5260 +

return {};

5261 +

}

5262 + 5237 5263

SmallVector<ValType, 6> values;

5238 5264

values.reserve(lhs.getNumElements());

5239 5265

for (const auto zip :

@@ -5315,40 +5341,60 @@ BINARY_FOLDER(RemOp, remainder);

5315 5341

BINARY_FOLDER(MaxOp, max);

5316 5342

BINARY_FOLDER(MinOp, min);

5317 5343 5318 -

OpFoldResult AddOp::fold(ArrayRef<Attribute> attrs) {

5319 -

if (attrs[0] && attrs[1]) {

5320 -

BINARY_FOLDER_INTERNAL(AddOp, std::plus)

5344 +

bool isSplatZero(SplatElementsAttr attr) {

5345 +

if (!attr) return false;

5346 +

if (attr.getElementType().isa<FloatType>()) {

5347 +

return attr.getSplatValue<APFloat>().isZero();

5348 +

} else if (attr.getElementType().isa<IntegerType>()) {

5349 +

return attr.getSplatValue<APInt>().isZero();

5350 +

} else {

5351 +

return false;

5321 5352

}

5353 +

}

5354 + 5355 +

OpFoldResult AddOp::fold(ArrayRef<Attribute> attrs) {

5322 5356

// Handle special case where one operand is 0: x + 0 => x

5323 5357

if (attrs[0] || attrs[1]) {

5324 -

SplatElementsAttr attr = attrs[0] ? attrs[0].dyn_cast<SplatElementsAttr>()

5325 -

: attrs[1].dyn_cast<SplatElementsAttr>();

5326 -

if (!attr) return {};

5327 -

Value result = attrs[0] ? rhs() : lhs();

5328 -

if (attr.getElementType().isa<FloatType>()) {

5329 -

if (attr.getSplatValue<APFloat>().isZero()) return result;

5330 -

} else if (attr.getElementType().isa<IntegerType>()) {

5331 -

if (attr.getSplatValue<APInt>().isZero()) return result;

5332 -

}

5358 +

SplatElementsAttr splat_lhs =

5359 +

attrs[0].dyn_cast_or_null<SplatElementsAttr>();

5360 +

SplatElementsAttr splat_rhs =

5361 +

attrs[1].dyn_cast_or_null<SplatElementsAttr>();

5362 +

if (isSplatZero(splat_lhs))

5363 +

return splat_rhs ? (OpFoldResult)splat_rhs : rhs();

5364 +

if (isSplatZero(splat_rhs))

5365 +

return splat_lhs ? (OpFoldResult)splat_lhs : lhs();

5366 +

}

5367 +

if (attrs[0] && attrs[1]) {

5368 +

BINARY_FOLDER_INTERNAL(AddOp, std::plus)

5333 5369

}

5334 5370

return {};

5335 5371

}

5336 5372 5337 -

OpFoldResult MulOp::fold(ArrayRef<Attribute> attrs) {

5338 -

if (attrs[0] && attrs[1]) {

5339 -

BINARY_FOLDER_INTERNAL(MulOp, std::multiplies);

5373 +

bool isSplatOne(SplatElementsAttr attr) {

5374 +

if (!attr) return false;

5375 +

if (attr.getElementType().isa<FloatType>()) {

5376 +

return attr.getSplatValue<APFloat>().convertToDouble() == 1.0;

5377 +

} else if (attr.getElementType().isa<IntegerType>()) {

5378 +

return attr.getSplatValue<APInt>().getSExtValue() == 1;

5379 +

} else {

5380 +

return false;

5340 5381

}

5382 +

}

5383 + 5384 +

OpFoldResult MulOp::fold(ArrayRef<Attribute> attrs) {

5341 5385

// Handle special case where one operand is 1: x * 1 => x

5342 5386

if (attrs[0] || attrs[1]) {

5343 -

SplatElementsAttr attr = attrs[0] ? attrs[0].dyn_cast<SplatElementsAttr>()

5344 -

: attrs[1].dyn_cast<SplatElementsAttr>();

5345 -

if (!attr) return {};

5346 -

Value result = attrs[0] ? rhs() : lhs();

5347 -

if (attr.getElementType().isa<FloatType>()) {

5348 -

if (attr.getSplatValue<APFloat>().convertToDouble() == 1.0) return result;

5349 -

} else if (attr.getElementType().isa<IntegerType>()) {

5350 -

if (attr.getSplatValue<APInt>().getSExtValue() == 1) return result;

5351 -

}

5387 +

SplatElementsAttr splat_lhs =

5388 +

attrs[0].dyn_cast_or_null<SplatElementsAttr>();

5389 +

SplatElementsAttr splat_rhs =

5390 +

attrs[1].dyn_cast_or_null<SplatElementsAttr>();

5391 +

if (isSplatOne(splat_lhs))

5392 +

return splat_rhs ? (OpFoldResult)splat_rhs : rhs();

5393 +

if (isSplatOne(splat_rhs))

5394 +

return splat_lhs ? (OpFoldResult)splat_lhs : lhs();

5395 +

}

5396 +

if (attrs[0] && attrs[1]) {

5397 +

BINARY_FOLDER_INTERNAL(MulOp, std::multiplies);

5352 5398

}

5353 5399

return {};

5354 5400

}


RetroSearch is an open source project built by @garambo | Open a GitHub Issue

Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo

HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4