# engine.py (dnn/torch/osce/engine) from gitlab.xiph.org/xiph/opus.git

import torch
from tqdm import tqdm
import sys

def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
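    """Train model for one epoch and return the mean loss over the dataloader.

    Batches are expected to be dicts with keys 'signals', 'features',
    'periods', 'numbits' and 'target'. The learning-rate scheduler is
    stepped once per batch.
    """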

    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0


    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()


            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

            # calculate loss; if the model returns a list of outputs
            # (e.g. intermediate stages), average the loss over them
            if isinstance(output, list):
                loss = torch.zeros(1, device=device)
                for y in output:
                    loss = loss + criterion(target, y.squeeze(1))
                loss = loss / len(output)
            else:
                loss = criterion(target, output.squeeze(1))

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # apply sparsification if the model provides a sparsifier
            if hasattr(model, 'sparsifier'):
                model.sparsifier()

            # update running loss
            running_loss += float(loss.cpu())

            # update status bar; current_loss is the mean loss over the
            # batches seen since the last update (a single batch at i == 0)
            if i % log_interval == 0:
                denom = log_interval if i > 0 else 1
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}",
                                   current_loss=f"{(running_loss - previous_running_loss)/denom:8.7f}")
                previous_running_loss = running_loss


    running_loss /= len(dataloader)

    return running_loss

def evaluate(model, criterion, dataloader, device, log_interval=10):
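    """Evaluate model without gradient tracking and return the mean loss.

    Expects the same batch layout as train_one_epoch.
    """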

    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0


    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

            for i, batch in enumerate(tepoch):

                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

                # calculate loss
                loss = criterion(target, output.squeeze(1))

                # update running loss
                running_loss += float(loss.cpu())

                # update status bar; current_loss is the mean loss over the
                # batches seen since the last update (a single batch at i == 0)
                if i % log_interval == 0:
                    denom = log_interval if i > 0 else 1
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}",
                                       current_loss=f"{(running_loss - previous_running_loss)/denom:8.7f}")
                    previous_running_loss = running_loss


    running_loss /= len(dataloader)

    return running_loss
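

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original engine): wires the two entry
# points above to a dummy model and dataset so the expected batch layout is
# explicit. DummyDataset, DummyModel and all shapes below are illustrative
# placeholders, not the actual OSCE data or models.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader, Dataset

    class DummyDataset(Dataset):
        """Yields dict items with the keys the engine expects."""
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {
                'signals':  torch.randn(160, 2),           # (frames, channels); engine permutes to (channels, frames)
                'features': torch.randn(4, 20),            # placeholder conditioning features
                'periods':  torch.randint(32, 256, (4,)),  # placeholder pitch periods
                'numbits':  torch.rand(2),                 # placeholder rate information
                'target':   torch.randn(160),              # placeholder target signal
            }

    class DummyModel(torch.nn.Module):
        """Ignores the side inputs and maps signals to a target-shaped output."""
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(2, 1, kernel_size=1)

        def forward(self, signals, features, periods, numbits):
            return self.conv(signals)  # (batch, 1, frames); engine squeezes dim 1

    model     = DummyModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
    criterion = torch.nn.functional.mse_loss
    loader    = DataLoader(DummyDataset(), batch_size=4)

    train_loss = train_one_epoch(model, criterion, optimizer, loader, 'cpu', scheduler)
    eval_loss  = evaluate(model, criterion, loader, 'cpu')
    print(f"train loss: {train_loss:.6f}, eval loss: {eval_loss:.6f}")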