/
losses.py
1027 lines (881 loc) · 40.6 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 The TensorFlow Ranking Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines ranking losses as TF ops.
The losses here are used to train TF-Ranking models. They work with listwise
Tensors only.
"""
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import tensorflow as tf
from tensorflow_ranking.python import losses_impl
from tensorflow_ranking.python import utils
class RankingLossKey(object):
  """String identifiers for every ranking loss exposed by this module."""

  # Pairwise losses.
  PAIRWISE_HINGE_LOSS = 'pairwise_hinge_loss'
  PAIRWISE_LOGISTIC_LOSS = 'pairwise_logistic_loss'
  PAIRWISE_SOFT_ZERO_ONE_LOSS = 'pairwise_soft_zero_one_loss'
  PAIRWISE_MSE_LOSS = 'pairwise_mse_loss'
  YETI_LOGISTIC_LOSS = 'yeti_logistic_loss'
  CIRCLE_LOSS = 'circle_loss'
  # Listwise / pointwise losses.
  SOFTMAX_LOSS = 'softmax_loss'
  POLY_ONE_SOFTMAX_LOSS = 'poly_one_softmax_loss'
  UNIQUE_SOFTMAX_LOSS = 'unique_softmax_loss'
  SIGMOID_CROSS_ENTROPY_LOSS = 'sigmoid_cross_entropy_loss'
  MEAN_SQUARED_LOSS = 'mean_squared_loss'
  LIST_MLE_LOSS = 'list_mle_loss'
  # Metric-approximating losses.
  APPROX_NDCG_LOSS = 'approx_ndcg_loss'
  APPROX_MRR_LOSS = 'approx_mrr_loss'
  GUMBEL_APPROX_NDCG_LOSS = 'gumbel_approx_ndcg_loss'
  NEURAL_SORT_CROSS_ENTROPY_LOSS = 'neural_sort_cross_entropy_loss'
  GUMBEL_NEURAL_SORT_CROSS_ENTROPY_LOSS = 'gumbel_neural_sort_cross_entropy_loss'
  NEURAL_SORT_NDCG_LOSS = 'neural_sort_ndcg_loss'
  GUMBEL_NEURAL_SORT_NDCG_LOSS = 'gumbel_neural_sort_ndcg_loss'

  @classmethod
  def all_keys(cls) -> List[str]:
    """Returns every loss key string, in class-definition order.

    Only attributes whose names are entirely upper case are treated as keys,
    so methods and dunder attributes are skipped.
    """
    keys = []
    for attr_name, attr_value in vars(cls).items():
      if attr_name.isupper():
        keys.append(attr_value)
    return keys
class _LossFunctionMaker(object):
  """Builds a loss function that combines one or more ranking losses.

  The function produced by `make()` evaluates every requested loss on the same
  `(labels, logits, features)` inputs and returns their (optionally weighted)
  sum.
  """

  def __init__(
      self,
      loss_keys: Union[str, Sequence[str]],
      loss_weights: Optional[Sequence[Union[float, int]]] = None,
      weights_feature_name: Optional[str] = None,
      lambda_weight: Optional[losses_impl._LambdaWeight] = None,
      reduction: tf.compat.v1.losses.Reduction = tf.compat.v1.losses.Reduction
      .SUM_BY_NONZERO_WEIGHTS,
      name: Optional[str] = None,
      params: Optional[Mapping[str, Any]] = None,
      gumbel_params: Optional[Mapping[str, Any]] = None,
  ):
    """Initializes a loss function maker.

    Args:
      loss_keys: A string or list of strings representing loss keys defined in
        `RankingLossKey`. Listed loss functions will be combined in a weighted
        manner, with weights specified by `loss_weights`. If `loss_weights` is
        None, default weight of 1 will be used. The loss_keys could also be of
        form 'mean_squared_loss:0.1,softmax_loss:0.9' in which case it will be
        parsed as a list of loss keys `['mean_squared_loss', 'softmax_loss']`
        and corresponding loss weights `[0.1, 0.9]`.
      loss_weights: List of weights, same length as `loss_keys`. Used when
        merging losses to calculate the weighted sum of losses. If `None`, all
        losses are weighted equally with weight being 1.
      weights_feature_name: A string specifying the name of the weights feature
        in `features` dict.
      lambda_weight: A `_LambdaWeight` object created by factory methods like
        `create_ndcg_lambda_weight()`.
      reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
        reduce training loss over batch.
      name: A string used as the name for this loss.
      params: A string-keyed dictionary that contains any other loss-specific
        arguments.
      gumbel_params: A string-keyed dictionary that contains other
        `gumbel_softmax_sample` arguments.

    Raises:
      ValueError: If `loss_keys` is a string with encoded weights and
        `loss_weights` is also given.
    """
    # A string such as 'mean_squared_loss:0.1,softmax_loss:0.9' already encodes
    # the weights, so explicit `loss_weights` would be ambiguous.
    if isinstance(loss_keys, str) and (':' in loss_keys or ',' in loss_keys):
      if loss_weights is not None:
        raise ValueError(
            '`loss_weights` has to be None when weights are encoded in `loss_keys`.'
        )
      keys_to_weights = utils.parse_keys_and_weights(loss_keys)
      loss_keys = list(keys_to_weights.keys())
      loss_weights = list(keys_to_weights.values())
    self.loss_keys = loss_keys
    self.loss_weights = loss_weights
    self.weights_feature_name = weights_feature_name
    self.lambda_weight = lambda_weight
    self.reduction = reduction
    self.name = name
    # `None` becomes an empty dict so later `.update()` / `**` calls need no
    # special-casing.
    self.params = params or {}
    self.gumbel_params = gumbel_params or {}

  def build_key_to_loss_fn_mapping(
      self, loss_kwargs: Mapping[str, Any],
      loss_kwargs_with_lambda_weight: Mapping[str, Any],
      gbl_loss_kwargs: Mapping[str, Any]
  ) -> Dict[str, Tuple[Callable[..., tf.Tensor], Dict[str, Any]]]:
    """Builds the mapping from loss keys to loss functions.

    Args:
      loss_kwargs: Keyword arguments for losses that take no `lambda_weight`.
      loss_kwargs_with_lambda_weight: `loss_kwargs` plus `lambda_weight`.
      gbl_loss_kwargs: Keyword arguments built from the Gumbel-sampled labels,
        logits and weights; used by the Yeti and `GUMBEL_*` losses.

    Returns:
      A dict mapping each `RankingLossKey` value to a `(loss_fn, kwargs)` pair.
    """
    key_to_fn = {
        RankingLossKey.PAIRWISE_HINGE_LOSS:
            (_pairwise_hinge_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.PAIRWISE_LOGISTIC_LOSS:
            (_pairwise_logistic_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.PAIRWISE_SOFT_ZERO_ONE_LOSS:
            (_pairwise_soft_zero_one_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.PAIRWISE_MSE_LOSS:
            (_pairwise_mse_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.YETI_LOGISTIC_LOSS:
            (_pairwise_logistic_loss, gbl_loss_kwargs),
        RankingLossKey.CIRCLE_LOSS:
            (_circle_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.SOFTMAX_LOSS:
            (_softmax_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.POLY_ONE_SOFTMAX_LOSS:
            (_poly_one_softmax_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.UNIQUE_SOFTMAX_LOSS:
            (_unique_softmax_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.SIGMOID_CROSS_ENTROPY_LOSS:
            (_sigmoid_cross_entropy_loss, loss_kwargs),
        RankingLossKey.MEAN_SQUARED_LOSS: (_mean_squared_loss, loss_kwargs),
        RankingLossKey.LIST_MLE_LOSS:
            (_list_mle_loss, loss_kwargs_with_lambda_weight),
        RankingLossKey.APPROX_NDCG_LOSS: (_approx_ndcg_loss, loss_kwargs),
        RankingLossKey.APPROX_MRR_LOSS: (_approx_mrr_loss, loss_kwargs),
        RankingLossKey.GUMBEL_APPROX_NDCG_LOSS:
            (_approx_ndcg_loss, gbl_loss_kwargs),
        RankingLossKey.NEURAL_SORT_CROSS_ENTROPY_LOSS:
            (_neural_sort_cross_entropy_loss, loss_kwargs),
        RankingLossKey.GUMBEL_NEURAL_SORT_CROSS_ENTROPY_LOSS:
            (_neural_sort_cross_entropy_loss, gbl_loss_kwargs),
        RankingLossKey.NEURAL_SORT_NDCG_LOSS:
            (_neural_sort_ndcg_loss, loss_kwargs),
        RankingLossKey.GUMBEL_NEURAL_SORT_NDCG_LOSS:
            (_neural_sort_ndcg_loss, gbl_loss_kwargs),
    }
    return key_to_fn

  def make(self) -> utils.LossFunction:
    """Makes the loss function.

    Returns:
      A function _loss_fn(). See `_loss_fn()` for its signature.

    Raises:
      ValueError: If `reduction` is invalid.
      ValueError: If `loss_keys` is None or empty.
      ValueError: If `loss_keys` and `loss_weights` have different sizes.
    """
    if (self.reduction not in tf.compat.v1.losses.Reduction.all() or
        self.reduction == tf.compat.v1.losses.Reduction.NONE):
      raise ValueError(f'Invalid reduction: {self.reduction}')
    if not self.loss_keys:
      raise ValueError('loss_keys cannot be None or empty.')
    if isinstance(self.loss_keys, str):
      self.loss_keys = [self.loss_keys]
    elif not isinstance(self.loss_keys, list):
      # Accept any sequence (e.g. a tuple), as the type hint advertises.
      # Wrapping it in a list instead would produce one tuple-valued key that
      # can never match a `RankingLossKey`.
      self.loss_keys = list(self.loss_keys)
    if self.loss_weights:
      if len(self.loss_keys) != len(self.loss_weights):
        raise ValueError('loss_keys and loss_weights must have the same size.')

    gumbel_sampler = losses_impl.GumbelSampler(**self.gumbel_params)

    def _loss_fn(labels: utils.TensorLike, logits: utils.TensorLike,
                 features: Dict[str, utils.TensorLike]) -> tf.Tensor:
      """Computes a single loss or weighted combination of losses.

      Args:
        labels: A `Tensor` of the same shape as `logits` representing relevance.
        logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
          ranking score of the corresponding item.
        features: Dict of Tensors of shape [batch_size, list_size, ...] for
          per-example features and shape [batch_size, ...] for non-example
          context features.

      Returns:
        An op for a single loss or weighted combination of multiple losses.

      Raises:
        ValueError: If `loss_keys` is invalid.
      """
      weights = None
      if self.weights_feature_name:
        weights = tf.convert_to_tensor(
            value=features[self.weights_feature_name])
        # Convert weights to a 2-D Tensor.
        weights = utils.reshape_to_2d(weights)

      # Gumbel-sampled inputs feed the Yeti and `GUMBEL_*` losses.
      gbl_labels, gbl_logits, gbl_weights = gumbel_sampler.sample(
          labels, logits, weights=weights)

      loss_kwargs = {
          'labels': labels,
          'logits': logits,
          'weights': weights,
          'reduction': self.reduction,
          'name': self.name,
      }
      gbl_loss_kwargs = {
          'labels': gbl_labels,
          'logits': gbl_logits,
          'weights': gbl_weights,
          'reduction': self.reduction,
          'name': self.name,
      }
      loss_kwargs.update(self.params)
      gbl_loss_kwargs.update(self.params)
      if self.lambda_weight is not None:
        gbl_loss_kwargs['lambda_weight'] = self.lambda_weight
      loss_kwargs_with_lambda_weight = loss_kwargs.copy()
      loss_kwargs_with_lambda_weight['lambda_weight'] = self.lambda_weight

      key_to_fn = self.build_key_to_loss_fn_mapping(
          loss_kwargs, loss_kwargs_with_lambda_weight, gbl_loss_kwargs)

      # Obtain the list of loss ops.
      loss_ops = []
      for loss_key in self.loss_keys:
        if loss_key not in key_to_fn:
          raise ValueError(f'Invalid loss_key: {loss_key}.')
        loss_fn, kwargs = key_to_fn[loss_key]
        loss_ops.append(loss_fn(**kwargs))

      # Compute weighted combination of losses.
      if self.loss_weights:
        weighted_losses = []
        for loss_op, loss_weight in zip(loss_ops, self.loss_weights):
          weighted_losses.append(tf.multiply(loss_op, loss_weight))
      else:
        weighted_losses = loss_ops

      return tf.add_n(weighted_losses)

    return _loss_fn
def make_loss_fn(
    loss_keys: Union[str, Sequence[str]],
    loss_weights: Optional[Sequence[Union[float, int]]] = None,
    weights_feature_name: Optional[str] = None,
    lambda_weight: Optional[losses_impl._LambdaWeight] = None,
    reduction: tf.compat.v1.losses.Reduction = tf.compat.v1.losses.Reduction
    .SUM_BY_NONZERO_WEIGHTS,
    name: Optional[str] = None,
    params: Optional[Mapping[str, Any]] = None,
    gumbel_params: Optional[Mapping[str, Any]] = None,
) -> utils.LossFunction:
  """Builds a loss function from one loss or a weighted set of losses.

  This is a thin public wrapper around `_LossFunctionMaker`.

  Args:
    loss_keys: A `RankingLossKey` string or a list of them. Several keys are
      combined as a weighted sum using `loss_weights` (weight 1 each when
      `loss_weights` is None). A single string such as
      'mean_squared_loss:0.1,softmax_loss:0.9' is also accepted and parsed
      into keys `['mean_squared_loss', 'softmax_loss']` with weights
      `[0.1, 0.9]`.
    loss_weights: Weights matching `loss_keys` one-to-one, used for the
      weighted sum. `None` weights every loss by 1.
    weights_feature_name: Name of the weights feature inside the `features`
      dict.
    lambda_weight: A `_LambdaWeight`, e.g. from `create_ndcg_lambda_weight()`.
    reduction: A `tf.losses.Reduction` other than `NONE`, describing how the
      training loss is reduced over the batch.
    name: Name used for this loss.
    params: String-keyed dict of additional loss-specific arguments.
    gumbel_params: String-keyed dict of additional `gumbel_softmax_sample`
      arguments.

  Returns:
    A function _loss_fn(). See `_loss_fn()` for its signature.

  Raises:
    ValueError: If `reduction` is invalid.
    ValueError: If `loss_keys` is None or empty.
    ValueError: If `loss_keys` and `loss_weights` have different sizes.
  """
  maker = _LossFunctionMaker(
      loss_keys,
      loss_weights=loss_weights,
      weights_feature_name=weights_feature_name,
      lambda_weight=lambda_weight,
      reduction=reduction,
      name=name,
      params=params,
      gumbel_params=gumbel_params,
  )
  return maker.make()
class _LossMetricFunctionMaker(object):
  """Builds an evaluation metric function backed by a single ranking loss."""

  def __init__(self,
               loss_key: str,
               weights_feature_name: Optional[str] = None,
               lambda_weight: Optional[losses_impl._LambdaWeight] = None,
               name: Optional[str] = None):
    """Initializes a loss metric function maker.

    Args:
      loss_key: A key in `RankingLossKey`.
      weights_feature_name: A `string` specifying the name of the weights
        feature in `features` dict.
      lambda_weight: A `_LambdaWeight` object.
      name: A `string` used as the name for this metric.
    """
    self.loss_key = loss_key
    self.weights_feature_name = weights_feature_name
    self.lambda_weight = lambda_weight
    self.name = name

  def build_key_to_loss_fn_mapping(
      self, name: Optional[str],
      lambda_weight: Optional[losses_impl._LambdaWeight]
  ) -> Dict[str, losses_impl._RankingLoss]:
    """Builds the mapping from loss keys to loss objects used for evaluation.

    Gumbel-sampled training losses (the `GUMBEL_*` keys and the Yeti loss) are
    evaluated with their deterministic, non-sampled counterparts. Losses that
    accept a `lambda_weight` during training receive it here too, so the
    metric matches the training objective.

    Args:
      name: Name passed to every loss object.
      lambda_weight: Optional `_LambdaWeight` for losses that support it.

    Returns:
      A dict from `RankingLossKey` values to `_RankingLoss` instances.
    """
    metric_dict = {
        RankingLossKey.PAIRWISE_HINGE_LOSS:
            losses_impl.PairwiseHingeLoss(name, lambda_weight=lambda_weight),
        RankingLossKey.PAIRWISE_LOGISTIC_LOSS:
            losses_impl.PairwiseLogisticLoss(name, lambda_weight=lambda_weight),
        RankingLossKey.PAIRWISE_SOFT_ZERO_ONE_LOSS:
            losses_impl.PairwiseSoftZeroOneLoss(
                name, lambda_weight=lambda_weight),
        RankingLossKey.PAIRWISE_MSE_LOSS:
            losses_impl.PairwiseMSELoss(name, lambda_weight=lambda_weight),
        # Yeti trains a pairwise logistic loss on Gumbel-sampled inputs; like
        # the other Gumbel variants it is evaluated without sampling.
        RankingLossKey.YETI_LOGISTIC_LOSS:
            losses_impl.PairwiseLogisticLoss(name, lambda_weight=lambda_weight),
        # The training path (`_circle_loss`) forwards `lambda_weight`; keep the
        # metric consistent with it.
        RankingLossKey.CIRCLE_LOSS:
            losses_impl.CircleLoss(name, lambda_weight=lambda_weight),
        RankingLossKey.SOFTMAX_LOSS:
            losses_impl.SoftmaxLoss(name, lambda_weight=lambda_weight),
        RankingLossKey.POLY_ONE_SOFTMAX_LOSS:
            losses_impl.PolyOneSoftmaxLoss(name, lambda_weight=lambda_weight),
        RankingLossKey.UNIQUE_SOFTMAX_LOSS:
            losses_impl.UniqueSoftmaxLoss(name, lambda_weight=lambda_weight),
        RankingLossKey.SIGMOID_CROSS_ENTROPY_LOSS:
            losses_impl.SigmoidCrossEntropyLoss(name),
        RankingLossKey.MEAN_SQUARED_LOSS:
            losses_impl.MeanSquaredLoss(name),
        RankingLossKey.LIST_MLE_LOSS:
            losses_impl.ListMLELoss(name, lambda_weight=lambda_weight),
        RankingLossKey.APPROX_NDCG_LOSS:
            losses_impl.ApproxNDCGLoss(name),
        RankingLossKey.APPROX_MRR_LOSS:
            losses_impl.ApproxMRRLoss(name),
        RankingLossKey.GUMBEL_APPROX_NDCG_LOSS:
            losses_impl.ApproxNDCGLoss(name),
        RankingLossKey.NEURAL_SORT_CROSS_ENTROPY_LOSS:
            losses_impl.NeuralSortCrossEntropyLoss(name),
        RankingLossKey.GUMBEL_NEURAL_SORT_CROSS_ENTROPY_LOSS:
            losses_impl.NeuralSortCrossEntropyLoss(name),
        RankingLossKey.NEURAL_SORT_NDCG_LOSS:
            losses_impl.NeuralSortNDCGLoss(name),
        RankingLossKey.GUMBEL_NEURAL_SORT_NDCG_LOSS:
            losses_impl.NeuralSortNDCGLoss(name),
    }
    return metric_dict

  def make(self) -> utils.MetricFunction:
    """Makes the loss metric function.

    Returns:
      A function _metric_fn(). See `_metric_fn()` for its signature.
    """

    def _get_weights(features: Dict[str, utils.TensorLike]) -> utils.TensorLike:
      """Gets weights tensor from features and reshape it to 2-D if necessary.

      Returns None when no `weights_feature_name` was configured.
      """
      weights = None
      if self.weights_feature_name:
        weights = tf.convert_to_tensor(
            value=features[self.weights_feature_name])
        # Convert weights to a 2-D Tensor.
        weights = utils.reshape_to_2d(weights)
      return weights

    def _metric_fn(labels: utils.TensorLike, predictions: utils.TensorLike,
                   features: Dict[str, utils.TensorLike]) -> tf.Tensor:
      """Evaluates the configured loss as a metric.

      Args:
        labels: A `Tensor` of the same shape as `predictions` representing
          graded relevance.
        predictions: A `Tensor` with shape [batch_size, list_size]. Each value
          is the ranking score of the corresponding example.
        features: A dict of `Tensor`s that contains all features.

      Returns:
        The result of the loss object's `eval_metric`.

      Raises:
        ValueError: If `loss_key` is not supported.
      """
      weights = _get_weights(features)
      metric_dict = self.build_key_to_loss_fn_mapping(self.name,
                                                      self.lambda_weight)
      loss = metric_dict.get(self.loss_key, None)
      if loss is None:
        raise ValueError(f'loss_key {self.loss_key} not supported.')
      return loss.eval_metric(labels, predictions, weights)

    return _metric_fn
def make_loss_metric_fn(loss_key,
                        weights_feature_name=None,
                        lambda_weight=None,
                        name=None):
  """Builds an evaluation metric that reports the value of a ranking loss.

  Thin public wrapper around `_LossMetricFunctionMaker`.

  Args:
    loss_key: One of the keys in `RankingLossKey`.
    weights_feature_name: Name of the weights feature within the `features`
      dict, or None for unweighted evaluation.
    lambda_weight: A `_LambdaWeight` object.
    name: Name used for this metric.

  Returns:
    A metric fn with the following Args:
    * `labels`: A `Tensor` of the same shape as `predictions` representing
    graded relevance.
    * `predictions`: A `Tensor` with shape [batch_size, list_size]. Each value
    is the ranking score of the corresponding example.
    * `features`: A dict of `Tensor`s that contains all features.
  """
  maker = _LossMetricFunctionMaker(
      loss_key,
      weights_feature_name=weights_feature_name,
      lambda_weight=lambda_weight,
      name=name)
  return maker.make()
def create_ndcg_lambda_weight(topn=None, smooth_fraction=0.):
  """Creates a `_LambdaWeight` tailored to the NDCG metric.

  The weight is a normalized `DCGLambdaWeight` with exponential gain
  `2^label - 1` and rank discount `1 / ln(1 + rank)`.

  Args:
    topn: Forwarded unchanged as the first positional argument of
      `losses_impl.DCGLambdaWeight` (presumably a rank cutoff; TODO confirm).
    smooth_fraction: Forwarded unchanged to `DCGLambdaWeight`.

  Returns:
    A `losses_impl.DCGLambdaWeight` instance.
  """

  def _exponential_gain(labels):
    return tf.pow(2.0, labels) - 1.

  def _log_discount(rank):
    return 1. / tf.math.log1p(rank)

  return losses_impl.DCGLambdaWeight(
      topn,
      gain_fn=_exponential_gain,
      rank_discount_fn=_log_discount,
      normalized=True,
      smooth_fraction=smooth_fraction)
def create_reciprocal_rank_lambda_weight(topn=None, smooth_fraction=0.):
  """Creates a `_LambdaWeight` for MRR-like metrics.

  The weight is a normalized `DCGLambdaWeight` whose gain is the raw label and
  whose rank discount is `1 / rank`.

  Args:
    topn: Forwarded unchanged as the first positional argument of
      `losses_impl.DCGLambdaWeight` (presumably a rank cutoff; TODO confirm).
    smooth_fraction: Forwarded unchanged to `DCGLambdaWeight`.

  Returns:
    A `losses_impl.DCGLambdaWeight` instance.
  """

  def _identity_gain(labels):
    return labels

  def _reciprocal_discount(rank):
    return 1. / rank

  return losses_impl.DCGLambdaWeight(
      topn,
      gain_fn=_identity_gain,
      rank_discount_fn=_reciprocal_discount,
      normalized=True,
      smooth_fraction=smooth_fraction)
def create_p_list_mle_lambda_weight(list_size):
  """Creates a `_LambdaWeight` for Position-Aware ListMLE.

  Implements the rank weighting from the "Position-Aware ListMLE" paper
  (Lan et al.): rank `r` is weighted by `2^(list_size - r) - 1`, so earlier
  positions get larger weights. Pass the result as the `lambda_weight` of the
  ListMLE loss.

  Args:
    list_size: Size of the input list.

  Returns:
    A `losses_impl.ListMLELambdaWeight` for Position-Aware ListMLE.
  """
  return losses_impl.ListMLELambdaWeight(
      rank_discount_fn=lambda rank: tf.pow(2., list_size - rank) - 1.)
def _pairwise_hinge_loss(
    labels,
    logits,
    weights=None,
    lambda_weight=None,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the pairwise hinge loss for a list.

  The hinge loss is defined as Hinge(l_i > l_j) = max(0, 1 - (s_i - s_j)). So a
  correctly ordered pair has 0 loss if (s_i - s_j >= 1). Otherwise the loss
  increases linearly as s_i - s_j decreases. When the list_size is 2, this
  reduces to the standard hinge loss.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `_LambdaWeight` object.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the pairwise hinge loss.
  """
  loss = losses_impl.PairwiseHingeLoss(name, lambda_weight)
  with tf.compat.v1.name_scope(loss.name, 'pairwise_hinge_loss',
                               (labels, logits, weights)):
    return loss.compute(labels, logits, weights, reduction)
def _pairwise_logistic_loss(
    labels,
    logits,
    weights=None,
    lambda_weight=None,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the pairwise logistic loss over each list.

  Each pair's preference probability is the sigmoid of the score difference,
  `P(l_i > l_j) = 1 / (1 + exp(s_j - s_i))`. A pair with `l_i > l_j`
  contributes `-log(P(l_i > l_j))`; every other pair contributes `0`.

  Args:
    labels: Graded-relevance `Tensor`, same shape as `logits`.
    logits: `Tensor` of shape [batch_size, list_size] holding each item's
      ranking score.
    weights: A scalar, a [batch_size, 1] `Tensor` of list-wise weights, or a
      [batch_size, list_size] `Tensor` of item-wise weights.
    lambda_weight: A `_LambdaWeight` object.
    reduction: Any `tf.losses.Reduction` except `NONE`; controls how the
      training loss is reduced over the batch.
    name: Name used for this loss.

  Returns:
    An op for the pairwise logistic loss.
  """
  ranking_loss = losses_impl.PairwiseLogisticLoss(name, lambda_weight)
  scope_inputs = (labels, logits, weights)
  with tf.compat.v1.name_scope(
      ranking_loss.name, 'pairwise_logistic_loss', scope_inputs):
    return ranking_loss.compute(labels, logits, weights, reduction)
def _pairwise_soft_zero_one_loss(
    labels,
    logits,
    weights=None,
    lambda_weight=None,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the pairwise soft zero-one loss.

  Unlike sigmoid cross entropy, this is a smooth but non-convex approximation
  of the zero-one loss. Each pair's preference probability is
  `P(l_i > l_j) = 1 / (1 + exp(s_j - s_i))`, and `1 - P(l_i > l_j)` is used
  directly as the loss. A correctly ordered pair therefore costs close to 0,
  while an incorrectly ordered pair costs at most 1.

  Args:
    labels: Graded-relevance `Tensor`, same shape as `logits`.
    logits: `Tensor` of shape [batch_size, list_size] holding each item's
      ranking score.
    weights: A scalar, a [batch_size, 1] `Tensor` of list-wise weights, or a
      [batch_size, list_size] `Tensor` of item-wise weights.
    lambda_weight: A `_LambdaWeight` object.
    reduction: Any `tf.losses.Reduction` except `NONE`; controls how the
      training loss is reduced over the batch.
    name: Name used for this loss.

  Returns:
    An op for the pairwise soft zero-one loss.
  """
  ranking_loss = losses_impl.PairwiseSoftZeroOneLoss(name, lambda_weight)
  scope_inputs = (labels, logits, weights)
  with tf.compat.v1.name_scope(
      ranking_loss.name, 'pairwise_soft_zero_one_loss', scope_inputs):
    return ranking_loss.compute(labels, logits, weights, reduction)
def _pairwise_mse_loss(
    labels,
    logits,
    weights=None,
    lambda_weight=None,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the pairwise MSE loss.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `_LambdaWeight` object.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the pairwise MSE loss.
  """
  loss = losses_impl.PairwiseMSELoss(name, lambda_weight)
  with tf.compat.v1.name_scope(loss.name, 'pairwise_mse_loss',
                               (labels, logits, weights)):
    return loss.compute(labels, logits, weights, reduction)
def _circle_loss(
    labels,
    logits,
    weights=None,
    lambda_weight=None,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None,
    gamma=64,
    margin=0.25):
  """Computes the pairwise circle loss for a list.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `_LambdaWeight` object.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.
    gamma: A float parameter used in circle loss.
    margin: A float parameter defining the margin in circle loss.

  Returns:
    An op for the pairwise circle loss.
  """
  loss = losses_impl.CircleLoss(name, lambda_weight, gamma, margin)
  with tf.compat.v1.name_scope(loss.name, 'pairwise_circle_loss',
                               (labels, logits, weights)):
    return loss.compute(labels, logits, weights, reduction)
def _softmax_loss(
    labels,
    logits,
    weights=None,
    lambda_weight=None,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the softmax cross entropy for a list.

  This is the ListNet loss originally proposed by Cao et al.
  ["Learning to Rank: From Pairwise Approach to Listwise Approach"] and is
  appropriate for datasets with binary relevance labels [see "An Analysis of
  the Softmax Cross Entropy Loss for Learning-to-Rank with Binary Relevance" by
  Bruch et al.]

  Given the labels l_i and the logits s_i, we sort the examples and obtain ranks
  r_i. The standard softmax loss doesn't need r_i and is defined as
      -sum_i l_i * log(exp(s_i) / (exp(s_1) + ... + exp(s_n))).
  The `lambda_weight` re-weights examples based on l_i and r_i:
      -sum_i w(l_i, r_i) * log(exp(s_i) / (exp(s_1) + ... + exp(s_n))).
  See 'individual_weights' in 'DCGLambdaWeight' for how w(l_i, r_i) is computed.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `DCGLambdaWeight` instance.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the softmax cross entropy as a loss.
  """
  loss = losses_impl.SoftmaxLoss(name, lambda_weight)
  with tf.compat.v1.name_scope(loss.name, 'softmax_loss',
                               (labels, logits, weights)):
    return loss.compute(labels, logits, weights, reduction)
def _poly_one_softmax_loss(
    labels,
    logits,
    weights=None,
    lambda_weight=None,
    epsilon=1.0,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the Poly-1 softmax cross entropy for a list.

  Based on Leng et al., "PolyLoss: A Polynomial Expansion Perspective of
  Classification Loss Functions" (ICLR 2022).

  Args:
    labels: Graded-relevance `Tensor`, same shape as `logits`.
    logits: `Tensor` of shape [batch_size, list_size] holding each item's
      ranking score.
    weights: A scalar, a [batch_size, 1] `Tensor` of list-wise weights, or a
      [batch_size, list_size] `Tensor` of item-wise weights.
    lambda_weight: A `DCGLambdaWeight` instance.
    epsilon: Scalar controlling the contribution of the first polynomial term.
    reduction: Any `tf.losses.Reduction` except `NONE`; controls how the
      training loss is reduced over the batch.
    name: Name used for this loss.

  Returns:
    An op for the Poly-1 softmax cross entropy as a loss.
  """
  ranking_loss = losses_impl.PolyOneSoftmaxLoss(name, lambda_weight, epsilon)
  scope_inputs = (labels, logits, weights)
  with tf.compat.v1.name_scope(
      ranking_loss.name, 'poly_one_softmax_loss', scope_inputs):
    return ranking_loss.compute(labels, logits, weights, reduction)
def _unique_softmax_loss(
    labels,
    logits,
    weights=None,
    lambda_weight=None,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the unique-rating softmax cross entropy (uRank) for a list.

  From Zhu and Klabjan, "Listwise Learning to Rank by Exploring Unique
  Ratings". It is suited to datasets with multiple relevance grades. For
  labels l_i and logits s_i the loss is
      -sum_i (2^l_i - 1) * log(exp(s_i) / (sum_j exp(s_j) + exp(s_i))),
  where j ranges over the documents with l_j < l_i.

  TODO: Add the lambda_weight support.
  NOTE(review): `lambda_weight` is nevertheless forwarded to
  `UniqueSoftmaxLoss`; confirm whether it currently has any effect.

  Args:
    labels: Graded-relevance `Tensor`, same shape as `logits`.
    logits: `Tensor` of shape [batch_size, list_size] holding each item's
      ranking score.
    weights: A scalar, a [batch_size, 1] `Tensor` of list-wise weights, or a
      [batch_size, list_size] `Tensor` of item-wise weights.
    lambda_weight: A `DCGLambdaWeight` instance.
    reduction: Any `tf.losses.Reduction` except `NONE`; controls how the
      training loss is reduced over the batch.
    name: Name used for this loss.

  Returns:
    An op for the unique-rating softmax cross entropy as a loss.
  """
  ranking_loss = losses_impl.UniqueSoftmaxLoss(name, lambda_weight)
  scope_inputs = (labels, logits, weights)
  with tf.compat.v1.name_scope(
      ranking_loss.name, 'unique_softmax_loss', scope_inputs):
    return ranking_loss.compute(labels, logits, weights, reduction)
def _sigmoid_cross_entropy_loss(
    labels,
    logits,
    weights=None,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the sigmoid cross entropy loss for a list.

  The sigmoid cross entropy between each position's graded-relevance label l_i
  and logit s_i is computed independently, and the per-position losses are
  then aggregated.

  Args:
    labels: Graded-relevance `Tensor`, same shape as `logits`.
    logits: `Tensor` of shape [batch_size, list_size] holding each item's
      ranking score.
    weights: A scalar, a [batch_size, 1] `Tensor` of list-wise weights, or a
      [batch_size, list_size] `Tensor` of item-wise weights.
    reduction: Any `tf.losses.Reduction` except `NONE`; controls how the
      training loss is reduced over the batch.
    name: Name used for this loss.

  Returns:
    An op for the sigmoid cross entropy as a loss.
  """
  ranking_loss = losses_impl.SigmoidCrossEntropyLoss(name)
  scope_inputs = (labels, logits, weights)
  with tf.compat.v1.name_scope(
      ranking_loss.name, 'sigmoid_cross_entropy_loss', scope_inputs):
    return ranking_loss.compute(labels, logits, weights, reduction)
def _mean_squared_loss(
    labels,
    logits,
    weights=None,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Builds the mean squared error loss for a list of items.

  For every position i, the squared error between the logit s_i and the
  graded relevance label l_i is computed. These per-position errors are then
  combined into a single value according to `reduction`.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the mean squared error as a loss.
  """
  mse = losses_impl.MeanSquaredLoss(name)
  scope_values = (labels, logits, weights)
  with tf.compat.v1.name_scope(mse.name, 'mean_squared_loss', scope_values):
    loss_value = mse.compute(labels, logits, weights, reduction)
  return loss_value
def _list_mle_loss(
    labels,
    logits,
    weights=None,
    lambda_weight=None,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the ListMLE loss (Xia et al., 2008) for a list.

  Given the labels of graded relevance l_i and the logits s_i, we calculate
  the ListMLE loss for the given list.

  The optional `lambda_weight` re-weights examples based on l_i and r_i.
  The recommended weighting scheme is the formulation presented in the
  "Position-Aware ListMLE" paper (Lan et al.). It is available through the
  `create_p_list_mle_lambda_weight()` factory function defined earlier in
  this module.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: An optional lambda weight instance used to re-weight
      examples, e.g. the object returned by
      `create_p_list_mle_lambda_weight()`. Defaults to None. It is passed
      unchanged to `losses_impl.ListMLELoss`.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the ListMLE loss.
  """
  loss = losses_impl.ListMLELoss(name, lambda_weight)
  with tf.compat.v1.name_scope(loss.name, 'list_mle_loss',
                               (labels, logits, weights)):
    return loss.compute(labels, logits, weights, reduction)
def _approx_ndcg_loss(labels,
                      logits,
                      weights=None,
                      reduction=tf.compat.v1.losses.Reduction.SUM,
                      name=None,
                      temperature=0.1):
  """Builds the ApproxNDCG ranking loss.

  ApproxNDCG is a smooth approximation to NDCG, introduced in "A general
  approximation framework for direct optimization of information retrieval
  measures" (Qin et al.). On datasets with graded relevance it is competitive
  with other state-of-the-art algorithms; see "Revisiting Approximate Metric
  Optimization in the Age of Deep Neural Networks" (Bruch et al.).

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights. If None, the weight of a list in the mini-batch is set to the
      sum of the labels of the items in that list.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.
    temperature: The temperature used to scale logits=logits/temperature.

  Returns:
    An op for the ApproxNDCG loss.
  """
  approx_ndcg = losses_impl.ApproxNDCGLoss(name, temperature=temperature)
  scope_values = (labels, logits, weights)
  with tf.compat.v1.name_scope(
      approx_ndcg.name, 'approx_ndcg_loss', scope_values):
    loss_value = approx_ndcg.compute(labels, logits, weights, reduction)
  return loss_value
def _approx_mrr_loss(labels,
                     logits,
                     weights=None,
                     reduction=tf.compat.v1.losses.Reduction.SUM,
                     name=None,
                     temperature=0.1):
  """Builds the ApproxMRR ranking loss.

  ApproxMRR is a smooth approximation to MRR, following "A general
  approximation framework for direct optimization of information retrieval
  measures" (Qin et al.).

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights. If None, the weight of a list in the mini-batch is set to the
      sum of the labels of the items in that list.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.
    temperature: The temperature used to scale logits=logits/temperature.

  Returns:
    An op for the ApproxMRR loss.
  """
  approx_mrr = losses_impl.ApproxMRRLoss(name, temperature=temperature)
  scope_values = (labels, logits, weights)
  with tf.compat.v1.name_scope(
      approx_mrr.name, 'approx_mrr_loss', scope_values):
    loss_value = approx_mrr.compute(labels, logits, weights, reduction)
  return loss_value
def _neural_sort_cross_entropy_loss(labels,
                                    logits,
                                    weights=None,
                                    reduction=tf.compat.v1.losses.Reduction.SUM,
                                    name=None,
                                    temperature=1.0):
  """Builds the NeuralSort cross-entropy ranking loss.

  Two approximate permutation matrices are formed: one derived from the
  labels and one derived from the logits. The loss is the cross entropy
  between them.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights. If None, the weight of a list in the mini-batch is set to the
      sum of the labels of the items in that list.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.
    temperature: The temperature used to scale logits=logits/temperature.

  Returns:
    An op for the NeuralSort CrossEntropy loss.
  """
  neural_sort_ce = losses_impl.NeuralSortCrossEntropyLoss(
      name, temperature=temperature)
  scope_values = (labels, logits, weights)
  with tf.compat.v1.name_scope(
      neural_sort_ce.name, 'neural_sort_cross_entropy_loss', scope_values):
    loss_value = neural_sort_ce.compute(labels, logits, weights, reduction)
  return loss_value
def _neural_sort_ndcg_loss(labels,
logits,
weights=None,
reduction=tf.compat.v1.losses.Reduction.SUM,
name=None,