Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also compare across forks or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: m-labs/artiq
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 668928a16c31
Choose a base ref
...
head repository: m-labs/artiq
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: eb8d630148c5
Choose a head ref
  • 7 commits
  • 5 files changed
  • 1 contributor

Commits on Dec 5, 2016

  1. Copy the full SHA
    74fe5c3 View commit details
  2. rtio: DMA fixes

    sbourdeauducq committed Dec 5, 2016
    Copy the full SHA
    30bce5a View commit details
  3. Copy the full SHA
    a583476 View commit details
  4. rtio: DMA unittest WIP

    sbourdeauducq committed Dec 5, 2016
    Copy the full SHA
    43a5455 View commit details
  5. Copy the full SHA
    75ea137 View commit details
  6. Copy the full SHA
    b677c69 View commit details
  7. Copy the full SHA
    eb8d630 View commit details
129 changes: 89 additions & 40 deletions artiq/gateware/rtio/dma.py
Original file line number Diff line number Diff line change
@@ -52,8 +52,6 @@ def __init__(self, membus, enable):
# All numbers in bytes
self.base_address = CSRStorage(aw + data_alignment,
alignment_bits=data_alignment)
self.last_address = CSRStorage(aw + data_alignment,
alignment_bits=data_alignment)

# # #

@@ -64,16 +62,14 @@ def __init__(self, membus, enable):
If(enable & ~enable_r,
address.address.eq(self.base_address.storage),
address.eop.eq(0),
address.stb.eq(1)
address.stb.eq(1),
),
If(address.stb & address.ack,
If(address.eop,
address.stb.eq(0)
).Else(
address.address.eq(address.address + 1),
If(~enable | (address.address == self.last_address.storage),
address.eop.eq(1)
)
If(~enable, address.eop.eq(1))
)
)
]
@@ -87,14 +83,16 @@ def __init__(self, in_size, out_size, granularity):
self.source = Signal(out_size*g)
self.source_stb = Signal()
self.source_consume = Signal(max=out_size+1)
self.flush = Signal()
self.flush_done = Signal()

# # #

# worst-case buffer space required (when loading):
# <data being shifted out> <new incoming word> <EOP marker>
buf_size = out_size - 1 + in_size + 1
# <data being shifted out> <new incoming word>
buf_size = out_size - 1 + in_size
buf = Signal(buf_size*g)
self.comb += self.source.eq(buf[:out_size])
self.comb += self.source.eq(buf[:out_size*8])

level = Signal(max=buf_size+1)
next_level = Signal(max=buf_size+1)
@@ -106,9 +104,7 @@ def __init__(self, in_size, out_size, granularity):

self.sync += [
If(load_buf, Case(level,
# note how the MSBs of the buffer are set to 0
# (including the EOP marker position)
{i: buf[i*g:].eq(self.sink.data)
{i: buf[i*g:(i+in_size)*g].eq(self.sink.data)
for i in range(out_size)})),
If(shift_buf, buf.eq(buf >> self.source_consume*g))
]
@@ -120,23 +116,28 @@ def __init__(self, in_size, out_size, granularity):
self.sink.ack.eq(1),
load_buf.eq(1),
If(self.sink.stb,
If(self.sink.eop,
# insert <granularity> bits of 0 to mark EOP
next_level.eq(level + in_size + 1)
).Else(
next_level.eq(level + in_size)
)
next_level.eq(level + in_size)
),
If(next_level >= out_size, NextState("OUTPUT"))
)
fsm.act("OUTPUT",
self.source_stb.eq(1),
shift_buf.eq(1),
next_level.eq(level - self.source_consume),
If(next_level < out_size, NextState("FETCH"))
If(next_level < out_size, NextState("FETCH")),
If(self.flush, NextState("FLUSH"))
)
fsm.act("FLUSH",
next_level.eq(0),
self.sink.ack.eq(1),
If(self.sink.stb & self.sink.eop,
self.flush_done.eq(1),
NextState("FETCH")
)
)


# end marker is a record with length=0
record_layout = [
("length", 8), # of whole record (header+data)
("channel", 24),
@@ -149,36 +150,64 @@ def __init__(self, in_size, out_size, granularity):
class RecordConverter(Module):
def __init__(self, stream_slicer):
self.source = stream.Endpoint(record_layout)
self.end_marker_found = Signal()
self.flush = Signal()

hdrlen = layout_len(record_layout) - 512
hdrlen = (layout_len(record_layout) - 512)//8
record_raw = Record(record_layout)
self.comb += [
record_raw.raw_bits().eq(stream_slicer.source),

self.source.channel.eq(record_raw.channel),
self.source.timestamp.eq(record_raw.timestamp),
self.source.address.eq(record_raw.address),
Case(record_raw.length,
{hdrlen+i*8: self.source.data.eq(record_raw.data[:])
for i in range(512//8)}),
{hdrlen+i: self.source.data.eq(record_raw.data[:i*8])
for i in range(1, 512//8+1)}),
]

self.source.stb.eq(stream_slicer.source_stb),
self.source.eop.eq(record_raw.length == 0),
If(self.source.ack,
fsm = FSM(reset_state="FLOWING")
self.submodules += fsm

fsm.act("FLOWING",
If(stream_slicer.source_stb,
If(record_raw.length == 0,
stream_slicer.source_consume.eq(1)
NextState("END_MARKER_FOUND")
).Else(
stream_slicer.source_consume.eq(record_raw.length)
self.source.stb.eq(1)
)
),
If(self.source.ack,
stream_slicer.source_consume.eq(record_raw.length)
)
]
)
fsm.act("END_MARKER_FOUND",
self.end_marker_found.eq(1),
If(self.flush,
stream_slicer.flush.eq(1),
NextState("WAIT_FLUSH")
)
)
fsm.act("WAIT_FLUSH",
If(stream_slicer.flush_done,
NextState("SEND_EOP")
)
)
fsm.act("SEND_EOP",
self.source.eop.eq(1),
self.source.stb.eq(1),
If(self.source.ack, NextState("FLOWING"))
)


class RecordSlicer(Module):
def __init__(self, in_size):
self.submodules.raw_slicer = RawSlicer(
in_size, layout_len(record_layout)//8, 8)
self.submodules.raw_slicer = ResetInserter()(RawSlicer(
in_size//8, layout_len(record_layout)//8, 8))
self.submodules.record_converter = RecordConverter(self.raw_slicer)

self.end_marker_found = self.record_converter.end_marker_found
self.flush = self.record_converter.flush

self.sink = self.raw_slicer.sink
self.source = self.record_converter.source

@@ -198,6 +227,7 @@ def __init__(self):
leave_out={"timestamp"}),
self.source.payload.timestamp.eq(self.sink.payload.timestamp
+ self.time_offset.storage),
self.source.eop.eq(self.sink.eop),
self.source.stb.eq(self.sink.stb)
)
self.comb += [
@@ -251,14 +281,22 @@ def __init__(self):
self.cri.chan_sel.eq(self.sink.channel),
self.cri.o_timestamp.eq(self.sink.timestamp),
self.cri.o_address.eq(self.sink.address),
self.cri.o_data.eq(self.sink.data)
]

fsm = FSM(reset_state="IDLE")
self.submodules += fsm

fsm.act("IDLE",
If(self.error_status.status == 0,
If(self.sink.stb, NextState("WRITE"))
If(self.sink.stb,
If(self.sink.eop,
# last packet contains dummy data, discard it
self.sink.ack.eq(1)
).Else(
NextState("WRITE")
)
)
).Else(
# discard all data until errors are acked
self.sink.ack.eq(1)
@@ -271,7 +309,7 @@ def __init__(self):
)
fsm.act("CHECK_STATE",
self.busy.eq(1),
If(~self.cri.o_status,
If(self.cri.o_status == 0,
self.sink.ack.eq(1),
NextState("IDLE")
),
@@ -293,11 +331,10 @@ def __init__(self):

class DMA(Module):
def __init__(self, membus):
# shutdown procedure: set enable to 0, wait until busy=0
self.enable = CSRStorage()
self.busy = CSRStatus()
self.enable = CSR()

self.submodules.dma = DMAReader(membus, self.enable.storage)
flow_enable = Signal()
self.submodules.dma = DMAReader(membus, flow_enable)
self.submodules.slicer = RecordSlicer(len(membus.dat_w))
self.submodules.time_offset = TimeOffset()
self.submodules.cri_master = CRIMaster()
@@ -313,16 +350,28 @@ def __init__(self, membus):
self.submodules += fsm

fsm.act("IDLE",
If(self.enable.storage, NextState("FLOWING"))
If(self.enable.re & self.enable.r, NextState("FLOWING"))
)
fsm.act("FLOWING",
self.busy.status.eq(1),
self.enable.w.eq(1),
flow_enable.eq(1),
If(self.slicer.end_marker_found | (self.enable.re & ~self.enable.r),
NextState("FLUSH")
)
)
fsm.act("FLUSH",
self.enable.w.eq(1),
self.slicer.flush.eq(1),
NextState("WAIT_EOP")
)
fsm.act("WAIT_EOP",
self.enable.w.eq(1),
If(self.cri_master.sink.stb & self.cri_master.sink.ack & self.cri_master.sink.eop,
NextState("WAIT_CRI_MASTER")
)
)
fsm.act("WAIT_CRI_MASTER",
self.busy.status.eq(1),
self.enable.w.eq(1),
If(~self.cri_master.busy, NextState("IDLE"))
)

Empty file added artiq/test/gateware/__init__.py
Empty file.
Empty file.
Empty file.
96 changes: 96 additions & 0 deletions artiq/test/gateware/rtio/test_dma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import unittest
import random

from migen import *
from misoc.interconnect import wishbone

from artiq.gateware.rtio import dma, cri


def encode_n(n, min_length, max_length):
    """Encode a non-negative integer as a little-endian list of byte values.

    The least significant byte comes first. The result is zero-padded up to
    ``min_length`` bytes.

    :param n: non-negative integer to encode.
    :param min_length: minimum number of bytes in the result.
    :param max_length: maximum number of bytes allowed in the result.
    :return: list of integers in the range 0..255.
    :raises ValueError: if ``n`` is negative or needs more than
        ``max_length`` bytes.
    """
    # A negative value never reaches 0 under ">>= 8" (it settles at -1),
    # so without this check the loop below would spin forever.
    if n < 0:
        raise ValueError("cannot encode negative value {}".format(n))
    r = []
    while n:
        r.append(n & 0xff)
        n >>= 8
    # Multiplying a list by a non-positive count yields [], so this only
    # ever pads; it never truncates.
    r += [0]*(min_length - len(r))
    if len(r) > max_length:
        raise ValueError("value needs {} bytes, at most {} allowed"
                         .format(len(r), max_length))
    return r


def encode_record(channel, timestamp, address, data):
    """Serialize one record into a list of byte values.

    The result starts with a single length byte that counts the whole
    record (the length byte itself included), followed by the channel
    (3 bytes), timestamp (8 bytes), address (2 bytes) and data
    (1 to 64 bytes), each encoded by ``encode_n``.
    """
    # (value, min_length, max_length) for each field, in wire order.
    fields = (
        (channel, 3, 3),
        (timestamp, 8, 8),
        (address, 2, 2),
        (data, 1, 64),
    )
    body = [byte
            for value, shortest, longest in fields
            for byte in encode_n(value, shortest, longest)]
    # +1 accounts for the length byte itself.
    header = encode_n(len(body) + 1, 1, 1)
    return header + body


def pack(x, size):
    """Group byte values into little-endian words of ``size`` bytes each.

    Element ``j`` of each group lands in bits ``[8*j, 8*j+8)`` of the word.
    A final group shorter than ``size`` is packed as-is, so its upper bytes
    are zero.

    :param x: sequence of integers (byte values).
    :param size: number of bytes per word; must be at least 1.
    :return: list of integers, one per word.
    :raises ValueError: if ``size`` is smaller than 1.
    """
    # A negative size used to yield an empty result silently (all input
    # dropped) and a zero size failed with an unrelated ZeroDivisionError.
    if size < 1:
        raise ValueError("size must be a positive number of bytes")
    r = []
    for start in range(0, len(x), size):
        n = 0
        for j, w in enumerate(x[start:start + size]):
            n |= w << j*8
        r.append(n)
    return r


# Reference events as (channel, timestamp, address, data) tuples, in the
# argument order taken by encode_record().
test_writes = [
    (0x01, 0x23, 0x12, 0x33),
    (
        0x901,
        0x902,
        0x911,
        0xeeeeeeeeeeeeeefffffffffffffffffffffffffffffff28888177772736646717738388488,
    ),
    (0x81, 0x288, 0x88, 0x8888),
]


class TB(Module):
    """Test bench: a Wishbone SRAM preloaded with the encoded
    ``test_writes`` records, connected to a ``dma.DMA`` instance.

    :param ws: word size in bytes; the Wishbone interface is created
        ``ws*8`` bits wide and the byte stream is packed into words of
        ``ws`` bytes.
    """
    def __init__(self, ws):
        payload = []
        for write in test_writes:
            payload.extend(encode_record(*write))
        # A single zero byte follows the last record (a length byte of 0).
        payload.append(0)
        words = pack(payload, ws)

        bus = wishbone.Interface(ws*8)
        self.submodules.memory = wishbone.SRAM(1024, init=words, bus=bus)
        self.submodules.dut = dma.DMA(bus)


class TestDMA(unittest.TestCase):
    def test_dma_noerror(self):
        """Start the DMA twice and check that the writes observed on the
        CRI equal ``test_writes`` repeated twice."""
        rng = random.Random(0)  # fixed seed keeps the run reproducible
        word_size = 64
        tb = TB(word_size)

        def do_dma():
            # Write 1 to the enable CSR, then poll it until it reads 0;
            # do this two times in a row.
            for _ in range(2):
                yield from tb.dut.enable.write(1)
                yield
                while (yield from tb.dut.enable.read()):
                    yield

        observed = []

        @passive
        def rtio_sim():
            # Watches the CRI, records every "write" command and then
            # holds o_status at 1 for a random number (0..9) of bare
            # yields before clearing it again.
            bus = tb.dut.cri
            while True:
                cmd = yield bus.cmd
                if cmd == cri.commands["nop"]:
                    pass
                elif cmd == cri.commands["write"]:
                    # Sample the fields in the same order as before:
                    # channel, timestamp, address, data.
                    event = (
                        (yield bus.chan_sel),
                        (yield bus.o_timestamp),
                        (yield bus.o_address),
                        (yield bus.o_data),
                    )
                    observed.append(event)

                    yield bus.o_status.eq(1)
                    for _ in range(rng.randrange(10)):
                        yield
                    yield bus.o_status.eq(0)
                else:
                    self.fail("unexpected RTIO command")
                yield

        run_simulation(tb, [do_dma(), rtio_sim()])
        self.assertEqual(observed, test_writes*2)